[llvm] [SimplifyCFG] Fix uint32_t overflow in cbranch to cbranch merge prevention check. (PR #72329)

Valery Pykhtin via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 22:54:17 PST 2023


https://github.com/vpykhtin updated https://github.com/llvm/llvm-project/pull/72329

>From 2cfd9cf7fa1e1a457fc6ddc6a466b1dfdc0e95af Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 01:17:05 +0100
Subject: [PATCH 1/3] test before fix

---
 .../SimplifyCFG/branch-cond-dont-merge.ll     | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index 5c21f163826ee25..3e924dc15bbe123 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -79,6 +79,25 @@ exit:
   ret void
 }
 
+define void @uint32_overflow_test(i1 %arg, i1 %arg1) {
+; CHECK-LABEL: @uint32_overflow_test(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    ret void
+;
+bb:
+  br i1 %arg, label %bb4, label %bb2, !prof !3
+
+bb2:
+  br i1 %arg1, label %bb4, label %bb3
+
+bb3:
+  br label %bb4
+
+bb4:
+  ret void
+}
+
 !0 = !{!"branch_weights", i32 1, i32 1000}
 !1 = !{!"branch_weights", i32 1000, i32 1}
 !2 = !{!"branch_weights", i32 3, i32 2}
+!3 = !{!"branch_weights", i32 -258677585, i32 -1212131848}

>From 5ffc278fcc5af769629a7e1a03f51021cde0412f Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 01:18:39 +0100
Subject: [PATCH 2/3] fix

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp                  | 4 ++--
 llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 6009558efca06af..f15ab6858401fc6 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4351,10 +4351,10 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
   SmallVector<uint32_t, 2> PredWeights;
   if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
       extractBranchWeights(*PBI, PredWeights) &&
-      (PredWeights[0] + PredWeights[1]) != 0) {
+      ((uint64_t)PredWeights[0] + PredWeights[1]) != 0) {
 
     BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
-        PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);
+        PredWeights[PBIOp], (uint64_t)PredWeights[0] + PredWeights[1]);
 
     BranchProbability Likely = TTI.getPredictableBranchThreshold();
     if (CommonDestProb >= Likely)
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index 3e924dc15bbe123..0f540007b242e29 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -82,6 +82,7 @@ exit:
 define void @uint32_overflow_test(i1 %arg, i1 %arg1) {
 ; CHECK-LABEL: @uint32_overflow_test(
 ; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[ARG:%.*]], i1 true, i1 [[ARG1:%.*]]
 ; CHECK-NEXT:    ret void
 ;
 bb:

>From fea31bc3922883a81a1d0be0da5521c391491eb2 Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 07:53:45 +0100
Subject: [PATCH 3/3] use static_cast

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index f15ab6858401fc6..299a143f30fe211 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4351,10 +4351,11 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
   SmallVector<uint32_t, 2> PredWeights;
   if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
       extractBranchWeights(*PBI, PredWeights) &&
-      ((uint64_t)PredWeights[0] + PredWeights[1]) != 0) {
+      (static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]) != 0) {
 
     BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
-        PredWeights[PBIOp], (uint64_t)PredWeights[0] + PredWeights[1]);
+        PredWeights[PBIOp],
+        static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]);
 
     BranchProbability Likely = TTI.getPredictableBranchThreshold();
     if (CommonDestProb >= Likely)



More information about the llvm-commits mailing list