[PATCH] D131287: Fix branch weight in FoldCondBranchOnValueKnownInPredecessor pass in SimplifyCFG

Zhi Zhuang via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 5 13:19:57 PDT 2022

LukeZhuang created this revision.
LukeZhuang added reviewers: lebedev.ri, lattner, nikic, danielcdh, manmanren, aeubanks, eraman, atrick.
LukeZhuang added a project: LLVM.
Herald added a subscriber: hiraditya.
Herald added a project: All.
LukeZhuang requested review of this revision.
Herald added a subscriber: llvm-commits.

The patch is intended to fix this test case:

  extern void bar();
  void foo(char a, char b, char c) {
    char aa = 2 * a, bb = 2 * b, cc = 2 * c;
    if (__builtin_expect_with_probability(((aa == 2) || (bb == 2) || (cc == 0)), 1, 0)) {
      bar();
    }
  }

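With `__builtin_expect_with_probability(..., 1, 0)`, the user states that the condition is true with probability 0, i.e., that function `bar` is expected never to be called.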

Before the second SimplifyCFG pass, the LLVM IR is as follows:

  define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
    %4 = and i8 %0, 127
    %5 = icmp eq i8 %4, 1
    %6 = and i8 %1, 127
    %7 = icmp eq i8 %6, 1
    %8 = or i1 %5, %7
    br i1 %8, label %12, label %9
  9:                                                ; preds = %3
    %10 = and i8 %2, 127
    %11 = icmp eq i8 %10, 0
    br label %12
  12:                                               ; preds = %9, %3
    %13 = phi i1 [ true, %3 ], [ %11, %9 ]
    br i1 %13, label %14, label %15, !prof !5
  14:                                               ; preds = %12
    call void (...) @bar() #2
    br label %15
  15:                                               ; preds = %14, %12
    ret void
  !5 = !{!"branch_weights", i32 1, i32 2147483647}

Here, the probability of calling `bar()` is still effectively 0%. Block `%12` executes on every path, and the branch from `%12` to `%14` has weight 1 out of a total of 2147483648 (about 4.7e-10).
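To make the numbers concrete: the probability of an edge is its weight divided by the sum of all weights in the `branch_weights` node. Below is a minimal C++ sketch. The helper `edgeProbability` is hypothetical and is not an LLVM API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Probability of taking successor `Idx`, computed from the operands of a
// !{!"branch_weights", ...} node: its weight divided by the sum of all weights.
// Weights are unsigned 32-bit values; summing them in 64 bits cannot overflow
// for any realistic number of successors. Assumes at least one nonzero weight.
double edgeProbability(const std::vector<uint32_t> &Weights, std::size_t Idx) {
  uint64_t Total = 0;
  for (uint32_t W : Weights)
    Total += W;
  return static_cast<double>(Weights[Idx]) / static_cast<double>(Total);
}
```

For `!5` = {1, 2147483647}, the edge to `%14` gets 1/2147483648, about 4.7e-10.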

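Then `FoldCondBranchOnValueKnownInPredecessor` sees that `%13` is the constant `true` when `%12` is reached from the entry block. It therefore redirects the entry block straight to the call block (the old `%14`, renumbered `%12` below). The conditional branch of the old `%12` now terminates `%9`, with `!prof !5` unchanged, while the entry block's branch still has no weights. The IR becomes: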

  define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
    %4 = and i8 %0, 127
    %5 = icmp eq i8 %4, 1
    %6 = and i8 %1, 127
    %7 = icmp eq i8 %6, 1
    %8 = or i1 %5, %7
    br i1 %8, label %12, label %9
  9:                                                ; preds = %3
    %10 = and i8 %2, 127
    %11 = icmp eq i8 %10, 0
    br i1 %11, label %12, label %13, !prof !5
  12:                                               ; preds = %3, %9
    call void (...) @bar() #2
    br label %13
  13:                                               ; preds = %12, %9
    ret void
  !5 = !{!"branch_weights", i32 1, i32 2147483647}

Both the entry block and block `%9` end in a conditional `BranchInst` that targets block `%12`. `performBranchToCommonDestFolding` therefore merges them into a single branch on `%11 = or i1 %8, %10` and recalculates the branch weights.

So the final IR is as follows:

  define dso_local void @foo(i8 noundef signext %0, i8 noundef signext %1, i8 noundef signext %2) local_unnamed_addr #0 {
    %4 = and i8 %0, 127
    %5 = icmp eq i8 %4, 1
    %6 = and i8 %1, 127
    %7 = icmp eq i8 %6, 1
    %8 = or i1 %5, %7
    %9 = and i8 %2, 127
    %10 = icmp eq i8 %9, 0
    %11 = or i1 %8, %10
    br i1 %11, label %12, label %13, !prof !5
  12:                                               ; preds = %3
    call void (...) @bar() #2
    br label %13
  13:                                               ; preds = %3, %12
    ret void
  !5 = !{!"branch_weights", i32 -2147483647, i32 2147483647}

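Branch weights are unsigned 32-bit values, but the IR printer shows them as signed `i32`, so `i32 -2147483647` is actually 2147483649. The probability of calling `bar()` is therefore 2147483649 / (2147483649 + 2147483647), just over 50%. That is very different from the 0% the user specified.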
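The final weights can be reproduced with a small model of the merge. This sketch rests on two assumptions:

- `performBranchToCommonDestFolding` treats the entry branch's missing `!prof` as weights 1:1.
- It combines the weights as in the hypothetical `mergeOrWeights` below: the common destination is reached either directly from the entry block, or by falling through to `%9` and then taking its true edge.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Models folding
//   PBI: br i1 %x, label %Common, label %BB     (entry block)
//   BI:  br i1 %y, label %Common, label %Other  (block %9)
// into `br i1 (or %x, %y), label %Common, label %Other`.
// Assumption: a branch without !prof is passed in as weights 1:1.
std::pair<uint64_t, uint64_t> mergeOrWeights(uint64_t PredTrue, uint64_t PredFalse,
                                             uint64_t SuccTrue, uint64_t SuccFalse) {
  // %Common: PBI goes there directly (weighted by BI's total, since BI does not
  // run), or PBI falls through and BI takes its true edge.
  uint64_t NewTrue = PredTrue * (SuccTrue + SuccFalse) + PredFalse * SuccTrue;
  // %Other: PBI falls through and BI takes its false edge.
  uint64_t NewFalse = PredFalse * SuccFalse;
  return {NewTrue, NewFalse};
}

// Reinterprets the low 32 bits of a weight as signed, matching how the
// weight is displayed as an `i32` operand in textual IR.
int64_t asPrintedI32(uint64_t W) {
  uint32_t Bits = static_cast<uint32_t>(W);
  return Bits > INT32_MAX ? static_cast<int64_t>(Bits) - (int64_t{1} << 32) : Bits;
}
```

With entry weights 1:1 and `!5` = 1:2147483647, the merged weights are 2147483649 and 2147483647. These are printed as `i32 -2147483647, i32 2147483647`.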

We believe the issue is caused by `FoldCondBranchOnValueKnownInPredecessor` not updating the branch weights. The transformation it performs can be reduced to the following case:

    PredBB  other2
         \  /
  other1  BB
       \  /\
        \/  \
   RealDest  other3

The block execution frequency of `RealDest` in this graph can be written as:

  Formula 1:
  BlockFreq(RealDest)
  = BlockFreq(BB) * P(BB->RealDest) + other1
  = (BlockFreq(PredBB) * P(PredBB->BB) + other2) * P(BB->RealDest) + other1
  = BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + other2 * P(BB->RealDest) + other1

where `P(A->B)` is the probability of the edge from block A to block B. `other1` and `other2` denote the frequency flowing into `RealDest` and `BB`, respectively, along their other incoming edges.

Then after the pass, the graph is updated to:

    PredBB   other2
         |     /
         |    /
  other1 |   BB
       \ |   /\
        \|  /  \
     RealDest    other3

The edge from `PredBB` is redirected to `RealDest` when the condition of the BranchInst in `BB`, reached from `PredBB`, always selects `RealDest`.
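In the example IR, `%13 = phi i1 [ true, %3 ], [ %11, %9 ]` is the constant `true` when `%12` is entered from the entry block `%3`. The branch in `%12` therefore always goes to `%14` on that edge. Below is a toy C++ model of that check. The struct and function names are hypothetical and do not mirror LLVM's data structures.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy model (not LLVM's API): BB ends in `br i1 %phi, TrueDest, FalseDest`.
// For each predecessor, the phi's incoming value is either a known constant
// or unknown (std::nullopt, e.g. a computed value such as %11).
struct ToyCondBlock {
  std::map<std::string, std::optional<bool>> PhiIncoming;
  std::string TrueDest, FalseDest;
};

// Returns the successor BB is guaranteed to branch to when entered from
// Pred, or std::nullopt if the condition is not constant on that edge.
std::optional<std::string> knownDestFrom(const ToyCondBlock &BB,
                                         const std::string &Pred) {
  auto It = BB.PhiIncoming.find(Pred);
  if (It == BB.PhiIncoming.end() || !It->second)
    return std::nullopt;
  return *It->second ? BB.TrueDest : BB.FalseDest;
}
```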
The new block execution frequency of `RealDest` is:

  Formula 2:
  newBlockFreq(RealDest)
  = BlockFreq(PredBB) * newP(PredBB->RealDest) + BlockFreq(BB) * P(BB->RealDest) + other1

where `BlockFreq(BB)` now equals `other2`, since `PredBB` no longer branches to `BB`. Substituting it gives:

  Formula 3:
  newBlockFreq(RealDest)
  = BlockFreq(PredBB) * newP(PredBB->RealDest) + other2 * P(BB->RealDest) + other1

Thus, for Formula 3 to equal Formula 1, `newP(PredBB->RealDest)` in the new graph must equal `P(PredBB->BB) * P(BB->RealDest)` in the original graph.
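This requirement can be checked numerically by evaluating Formula 1 and Formula 3 with arbitrary example values. The function names below are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Formula 1: frequency of RealDest before the edge PredBB->BB is redirected.
double freqBefore(double FreqPred, double PPredToBB, double PBBToReal,
                  double Other1, double Other2) {
  double FreqBB = FreqPred * PPredToBB + Other2;
  return FreqBB * PBBToReal + Other1;
}

// Formula 3: frequency of RealDest after the redirect; BB is now reached
// only through other2.
double freqAfter(double FreqPred, double NewPPredToReal, double PBBToReal,
                 double Other1, double Other2) {
  return FreqPred * NewPPredToReal + Other2 * PBBToReal + Other1;
}
```

Suppose the redirected edge keeps the old probability `P(PredBB->BB)`, which is what happens when the weights are not updated. In that case `freqAfter` exceeds `freqBefore`, i.e., the frequency of `RealDest` is overestimated.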

After the redirect, the edges from `PredBB` that went to `BB` go to `RealDest` with their weights unchanged. One way to get the required probability is to multiply every branch weight from `PredBB` that does not go to `BB` by a factor `x`, using the following notation:
(1) `A` is the sum of all branch weights from `PredBB` to `BB` (`PTITotalTakenWeight` in the source code)
(2) `B` is the sum of the branch weights from `PredBB` that do not go to `BB` (`PTITotalWeight-PTITotalTakenWeight` in the source code)
(3) `S` is the sum of all branch weights from `PredBB`, equal to `A+B` (`PTITotalWeight` in the source code)
(4) `C` is the branch weight from `BB` to `RealDest` (`BITakenWeight` in the source code)
(5) `D` is the branch weight from `BB` to its other successor (`BINotTakenWeight` in the source code)

     newP(PredBB->RealDest) = P(PredBB->BB) * P(BB->RealDest)
  => A/(A+x*B) = A/(A+B) * C/(C+D)
  => A + x*B = S * (C+D)/C = S + S*D/C
  => x*B = B + S*D/C
  => x = 1 + (S/B) * (D/C)
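The closed form for `x` can be checked numerically with made-up weights. The functions below are hypothetical and assume `B > 0` and `C > 0`.

```cpp
#include <cassert>
#include <cmath>

// x = 1 + (S/B) * (D/C), where S = A + B. Assumes B > 0 and C > 0.
double scaleFactor(double A, double B, double C, double D) {
  double S = A + B;
  return 1.0 + (S / B) * (D / C);
}

// newP(PredBB->RealDest) once the weights not going to BB are scaled by X.
double redirectedProbability(double A, double B, double X) {
  return A / (A + X * B);
}
```

For A = 3, B = 5, C = 2, D = 6, this gives x = 5.8. The redirected edge then has probability 3/32, which equals (3/8) * (2/8).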

Then each branch weight from `PredBB` that does not go to `BB` is transformed as follows:

     NewWeight = OldWeight * x
  => NewWeight = OldWeight * (1 + (S/B) * (D/C))
  => NewWeight = OldWeight + (S*OldWeight/B) * (D/C)

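Computing the weight in this last form is intended to prevent overflow: since `OldWeight <= B`, the intermediate `S*OldWeight/B` never exceeds `S` (see the source code for details).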
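Below is a sketch of this update in 64-bit integer arithmetic. The helper `scaleOtherWeight` is hypothetical, not the patch's code. It assumes:

- `B > 0` and `C > 0`;
- `OldWeight <= B`;
- the weights on each terminator sum to at most UINT32_MAX, so both products fit in 64 bits.

The result may still exceed 32 bits and would then need to be scaled down before being stored (not shown).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: NewWeight = OldWeight + (S * OldWeight / B) * (D / C),
// evaluated with integer arithmetic. See the assumptions in the lead-in.
uint64_t scaleOtherWeight(uint64_t OldWeight, uint64_t S, uint64_t B,
                          uint64_t C, uint64_t D) {
  // OldWeight <= B, so this quotient never exceeds S.
  uint64_t Share = S * OldWeight / B;
  // Multiply by D before dividing by C to limit truncation.
  return OldWeight + Share * D / C;
}
```

Consider the running example, treating the entry block's missing weights as equal (A = B = 1, S = 2) and taking C = 1 and D = 2147483647 from `!5`. The weight on the entry block's edge to `%9` becomes 4294967295. The redirected edge to `%14` then gets probability 1/2^32, which equals (1/2) * (1/2^31).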

Also, because this change may attach branch weights to a BranchInst/SwitchInst that had none originally, we skip the cases where the new weights would all be equal, since equal weights are already the default behavior without metadata.
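One reading of this rule is sketched below as a hypothetical predicate; its name and signature are assumptions. New metadata is attached to a previously unweighted terminator only if the recomputed weights are not all equal.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical check: returns true if the recomputed weights should be
// written to the terminator.
bool worthAttachingWeights(bool HadProfMetadata,
                           const std::vector<uint64_t> &NewWeights) {
  if (HadProfMetadata)
    return true; // Existing metadata is always updated.
  // Equal weights would only restate the default, so attach only if some
  // adjacent pair of weights differs.
  return std::adjacent_find(NewWeights.begin(), NewWeights.end(),
                            std::not_equal_to<uint64_t>()) != NewWeights.end();
}
```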

REPOSITORY
  rG LLVM Github Monorepo



-------------- next part --------------
A non-text attachment was scrubbed...
Name: D131287.450365.patch
Type: text/x-patch
Size: 13753 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20220805/32de36d6/attachment.bin>
