[PATCH] D131287: Fix branch weight in FoldCondBranchOnValueKnownInPredecessor pass in SimplifyCFG
Zhi Zhuang via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 5 13:19:57 PDT 2022
LukeZhuang created this revision.
LukeZhuang added reviewers: lebedev.ri, lattner, nikic, danielcdh, manmanren, aeubanks, eraman, atrick.
LukeZhuang added a project: LLVM.
Herald added a subscriber: hiraditya.
Herald added a project: All.
LukeZhuang requested review of this revision.
Herald added a subscriber: llvm-commits.
The patch is intended to fix this test case:
extern void bar();
void foo(char a, char b, char c) {
  char aa = 2 * a, bb = 2 * b, cc = 2 * c;
  if (__builtin_expect_with_probability(((aa == 2) || (bb == 2) || (cc == 0)), 1, 0)) {
    bar();
  }
}
Here the user states that the condition is true with probability 0, so the call to `bar()` should have a probability of 0%.
Before the second SimplifyCFG pass, the LLVM IR is as follows:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9
9: ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br label %12
12: ; preds = %9, %3
  %13 = phi i1 [ true, %3 ], [ %11, %9 ]
  br i1 %13, label %14, label %15, !prof !5
14: ; preds = %12
  call void (...) @bar() #2
  br label %15
15: ; preds = %14, %12
  ret void
}
!5 = !{!"branch_weights", i32 1, i32 2147483647}
Here the probability of calling `bar()` is still effectively 0%: block 12 is always executed, and the branch from block 12 to block 14 has weight 1 against 2147483647.
After `FoldCondBranchOnValueKnownInPredecessor` runs, the entry block branches directly to the block that calls `bar()` (block 14 above, renumbered to block 12 below), and the IR becomes:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  br i1 %8, label %12, label %9
9: ; preds = %3
  %10 = and i8 %2, 127
  %11 = icmp eq i8 %10, 0
  br i1 %11, label %12, label %13, !prof !5
12: ; preds = %3, %9
  call void (...) @bar() #2
  br label %13
13: ; preds = %12, %9
  ret void
}
!5 = !{!"branch_weights", i32 1, i32 2147483647}
The entry block and block 9 both end in a conditional branch with block 12 as a successor, so `performBranchToCommonDestFolding` merges them and recalculates the branch weights.
So the final IR is as follows:
define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
  %4 = and i8 %0, 127
  %5 = icmp eq i8 %4, 1
  %6 = and i8 %1, 127
  %7 = icmp eq i8 %6, 1
  %8 = or i1 %5, %7
  %9 = and i8 %2, 127
  %10 = icmp eq i8 %9, 0
  %11 = or i1 %8, %10
  br i1 %11, label %12, label %13, !prof !5
12: ; preds = %3
  call void (...) @bar() #2
  br label %13
13: ; preds = %3, %12
  ret void
}
!5 = !{!"branch_weights", i32 -2147483647, i32 2147483647}
Read as unsigned 32-bit values, the weights are 2147483649 and 2147483647, so the probability of calling `bar()` is now just over 50%, which is very different from the 0% the user specified.
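The jump from about 0% to just over 50% can be reproduced with a little arithmetic. The sketch below is not the SimplifyCFG code. It assumes that a branch without `!prof` metadata is treated as having equal weights 1:1, and that the two branches are combined as a logical OR; both assumptions are consistent with the weights printed in the final IR. The helper names `takenProbability` and `foldOrWeights` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper: probability of taking the first successor of a
// two-way branch, given its (taken, not-taken) branch weights.
double takenProbability(uint64_t Taken, uint64_t NotTaken) {
  assert(Taken + NotTaken > 0 && "weights must not both be zero");
  return static_cast<double>(Taken) / static_cast<double>(Taken + NotTaken);
}

// Hypothetical helper: merge "br %x, Dest, BB" (predecessor) with
// "br %y, Dest, Other" (successor) into a single branch on (%x | %y).
//   taken     = PredTrue * (SuccTrue + SuccFalse) + PredFalse * SuccTrue
//   not taken = PredFalse * SuccFalse
// Assumption: this OR-style combination models the merge described above.
std::pair<uint64_t, uint64_t> foldOrWeights(uint64_t PredTrue,
                                            uint64_t PredFalse,
                                            uint64_t SuccTrue,
                                            uint64_t SuccFalse) {
  return {PredTrue * (SuccTrue + SuccFalse) + PredFalse * SuccTrue,
          PredFalse * SuccFalse};
}
```

Calling `foldOrWeights(1, 1, 1, 2147483647)` yields the pair (2147483649, 2147483647). The first value does not fit in a signed 32-bit integer, which is why the IR printer shows it as -2147483647, and `takenProbability` on that pair is slightly above 0.5.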
We believe the issue is caused by the branch weight not being updated in the `FoldCondBranchOnValueKnownInPredecessor` pass. The optimization that this pass does can be simplified to this case:
before:
      PredBB  other2
          \    /
           \  /
  other1    BB
      \    /  \
       \  /    \
     RealDest   other3
The block execution frequency of `RealDest` here can be represented as:
Formula 1:
BlockFreq(RealDest)
= BlockFreq(BB) * P(BB->RealDest) + other1
= (BlockFreq(PredBB) * P(PredBB->BB) + other2) * P(BB->RealDest) + other1
= BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + other2 * P(BB->RealDest) + other1
where `P(A->B)` is the probability of branching from block A to block B, and `other1` and `other2` stand for the frequency arriving from the other predecessors.
Then after the pass, the graph is updated to:
after:
     PredBB   other2
        |      /
        |     /
other1  |   BB
    \   |  /  \
     \  | /    \
     RealDest   other3
`PredBB` is redirected to `RealDest` when the branch condition in `BB` is known to select `RealDest` whenever control arrives from `PredBB`.
The new block execution frequency of `RealDest` is:
Formula 2:
BlockFreq(RealDest)
= BlockFreq(PredBB) * newP(PredBB->RealDest) + BlockFreq(BB) * P(BB->RealDest) + other1
in which `BlockFreq(BB)` is equal to `other2` since `PredBB` no longer points to `BB`. So it can be represented as:
Formula 3:
BlockFreq(RealDest)
= BlockFreq(PredBB) * newP(PredBB->RealDest) + other2 * P(BB->RealDest) + other1
Thus, to match Formula 1 and Formula 3, we need to make `newP(PredBB->RealDest)` in the new graph equal to `P(PredBB->BB) * P(BB->RealDest)` in the original graph.
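As a quick sanity check of this requirement, the two formulas can be evaluated numerically. This is only an illustrative sketch: the function names and the sample values in the note below are made up, and frequencies are modeled as plain doubles rather than LLVM's block frequency types.

```cpp
#include <cassert>

// Formula 1: frequency of RealDest before the transform.
// Other1 and Other2 are the frequencies arriving from other predecessors.
double realDestFreqBefore(double PredFreq, double PPredToBB, double PBBToDest,
                          double Other1, double Other2) {
  assert(PPredToBB >= 0.0 && PPredToBB <= 1.0);
  assert(PBBToDest >= 0.0 && PBBToDest <= 1.0);
  double BBFreq = PredFreq * PPredToBB + Other2;
  return BBFreq * PBBToDest + Other1;
}

// Formula 3: frequency of RealDest after PredBB is redirected to it.
// BB now only receives Other2.
double realDestFreqAfter(double PredFreq, double NewPPredToDest,
                         double PBBToDest, double Other1, double Other2) {
  assert(NewPPredToDest >= 0.0 && NewPPredToDest <= 1.0);
  return PredFreq * NewPPredToDest + Other2 * PBBToDest + Other1;
}
```

With PredFreq = 100, P(PredBB->BB) = 0.5, P(BB->RealDest) = 0.25, other1 = 10 and other2 = 40, Formula 1 gives 32.5. Formula 3 gives the same 32.5 only when newP(PredBB->RealDest) is 0.5 * 0.25 = 0.125; leaving it at the old 0.5 gives 70.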
One way to do this is to multiply every branch weight on an edge from `PredBB` that does not lead to `BB` by a factor `x`.
Supposing:
(1) `A` means the sum of all branch weights from `PredBB` to `BB`, which is `PTITotalTakenWeight` in the source code
(2) `B` means the sum of all branch weights from `PredBB` to blocks other than `BB`, which is `PTITotalWeight-PTITotalTakenWeight` in the source code
(3) `S` means the sum of all branch weights from `PredBB`, which is `PTITotalWeight` in the source code and equal to `A+B`
(4) `C` means the branch weight from `BB` to `RealDest`, which is `BITakenWeight` in the source code
(5) `D` means the branch weight from `BB` to the other successor, which is `BINotTakenWeight` in the source code
So:
newP(PredBB->RealDest) = P(PredBB->BB) * P(BB->RealDest)
=> A/(A+x*B) = A/(A+B) * C/(C+D)
=> A + x*B = (A+B) * (C+D)/C
=> x*B = B + S * (D/C)
=> x = 1 + (S/B) * (D/C)
Then for each branch weight from `PredBB` but not to `BB`, we do this transformation:
NewWeight = OldWeight * x
=> NewWeight = OldWeight * (1 + (S/B) * (D/C))
=> NewWeight = OldWeight + (S*OldWeight/B) * (D/C)
This last form is meant to avoid overflow in the intermediate computation (see the source code for details).
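The scaling can be illustrated with a small example. The sketch below is not the code from the patch: it uses floating point for brevity and makes no attempt at the overflow-safe integer arithmetic mentioned above. The helper names `scaleFactor` and `scaleWeightsNotToBB` and the `LeadsToBB` flag vector are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: x = 1 + (S/B) * (D/C), with S = A + B.
double scaleFactor(uint64_t A, uint64_t B, uint64_t C, uint64_t D) {
  assert(B > 0 && C > 0 && "x is undefined when B or C is zero");
  double S = static_cast<double>(A) + static_cast<double>(B);
  return 1.0 + (S / static_cast<double>(B)) *
                   (static_cast<double>(D) / static_cast<double>(C));
}

// Hypothetical helper: multiply every weight of PredBB's terminator that
// does not lead to BB by x. Weights on edges to BB are left unchanged.
// C and D are the weights of BB's branch to RealDest and to its other
// successor.
std::vector<double> scaleWeightsNotToBB(const std::vector<uint64_t> &Weights,
                                        const std::vector<bool> &LeadsToBB,
                                        uint64_t C, uint64_t D) {
  assert(Weights.size() == LeadsToBB.size());
  uint64_t A = 0, B = 0;
  for (std::size_t I = 0; I < Weights.size(); ++I)
    (LeadsToBB[I] ? A : B) += Weights[I];

  std::vector<double> NewWeights(Weights.begin(), Weights.end());
  if (B == 0)
    return NewWeights; // Nothing to scale: every edge leads to BB.

  double X = scaleFactor(A, B, C, D);
  for (std::size_t I = 0; I < Weights.size(); ++I)
    if (!LeadsToBB[I])
      NewWeights[I] *= X;
  return NewWeights;
}
```

For PredBB weights {1, 1} with only the first edge leading to BB, and C = 1, D = 3, the factor is x = 1 + 2 * 3 = 7 and the new weights are {1, 7}. The redirected edge then has probability 1/8, which equals P(PredBB->BB) * P(BB->RealDest) = 1/2 * 1/4.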
Also, since this change may attach branch weights to a BranchInst/SwitchInst that originally had none, we skip the update when it would only assign equal weights to every successor, which is the default behavior anyway.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D131287
Files:
llvm/lib/Transforms/Utils/SimplifyCFG.cpp
llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D131287.450365.patch
Type: text/x-patch
Size: 13753 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20220805/32de36d6/attachment.bin>