[llvm] [SimplifyCFG] Fix uint32_t overflow in cbranch to cbranch merge prevention check. (PR #72329)
Valery Pykhtin via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 14 22:54:17 PST 2023
https://github.com/vpykhtin updated https://github.com/llvm/llvm-project/pull/72329
>From 2cfd9cf7fa1e1a457fc6ddc6a466b1dfdc0e95af Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 01:17:05 +0100
Subject: [PATCH 1/3] test before fix
---
.../SimplifyCFG/branch-cond-dont-merge.ll | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index 5c21f163826ee25..3e924dc15bbe123 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -79,6 +79,25 @@ exit:
ret void
}
+define void @uint32_overflow_test(i1 %arg, i1 %arg1) {
+; CHECK-LABEL: @uint32_overflow_test(
+; CHECK-NEXT: bb:
+; CHECK-NEXT: ret void
+;
+bb:
+ br i1 %arg, label %bb4, label %bb2, !prof !3
+
+bb2:
+ br i1 %arg1, label %bb4, label %bb3
+
+bb3:
+ br label %bb4
+
+bb4:
+ ret void
+}
+
!0 = !{!"branch_weights", i32 1, i32 1000}
!1 = !{!"branch_weights", i32 1000, i32 1}
!2 = !{!"branch_weights", i32 3, i32 2}
+!3 = !{!"branch_weights", i32 -258677585, i32 -1212131848}
>From 5ffc278fcc5af769629a7e1a03f51021cde0412f Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 01:18:39 +0100
Subject: [PATCH 2/3] fix
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 4 ++--
llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll | 1 +
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 6009558efca06af..f15ab6858401fc6 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4351,10 +4351,10 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
SmallVector<uint32_t, 2> PredWeights;
if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
extractBranchWeights(*PBI, PredWeights) &&
- (PredWeights[0] + PredWeights[1]) != 0) {
+ ((uint64_t)PredWeights[0] + PredWeights[1]) != 0) {
BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
- PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);
+ PredWeights[PBIOp], (uint64_t)PredWeights[0] + PredWeights[1]);
BranchProbability Likely = TTI.getPredictableBranchThreshold();
if (CommonDestProb >= Likely)
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index 3e924dc15bbe123..0f540007b242e29 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -82,6 +82,7 @@ exit:
define void @uint32_overflow_test(i1 %arg, i1 %arg1) {
; CHECK-LABEL: @uint32_overflow_test(
; CHECK-NEXT: bb:
+; CHECK-NEXT: [[BRMERGE:%.*]] = select i1 [[ARG:%.*]], i1 true, i1 [[ARG1:%.*]]
; CHECK-NEXT: ret void
;
bb:
>From fea31bc3922883a81a1d0be0da5521c391491eb2 Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Wed, 15 Nov 2023 07:53:45 +0100
Subject: [PATCH 3/3] use static_cast
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index f15ab6858401fc6..299a143f30fe211 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4351,10 +4351,11 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
SmallVector<uint32_t, 2> PredWeights;
if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
extractBranchWeights(*PBI, PredWeights) &&
- ((uint64_t)PredWeights[0] + PredWeights[1]) != 0) {
+ (static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]) != 0) {
BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
- PredWeights[PBIOp], (uint64_t)PredWeights[0] + PredWeights[1]);
+ PredWeights[PBIOp],
+ static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]);
BranchProbability Likely = TTI.getPredictableBranchThreshold();
if (CommonDestProb >= Likely)
More information about the llvm-commits mailing list