[PATCH] D131287: Fix branch weight in FoldCondBranchOnValueKnownInPredecessor pass in SimplifyCFG

Paul Kirth via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 19 13:22:42 PDT 2022


paulkirth added a comment.

Hi, thanks for working on this.

I'm probably not the most qualified reviewer in this area, but I wanted to supply some feedback. In particular, I've focused on places where I found this patch hard to follow, and I've tried to make suggestions that would have aided my understanding.

I also agree w/ @nikic that this looks related to `JumpThreadingPass::updateBlockFreqAndEdgeWeight`. Do you have thoughts on why that is/isn't relevant?

As a side note, you may get better responses if the number of reviewers is smaller. People tend to assume someone else will review a change if there are lots of people on the list, more so if //anyone else// responds. That was certainly the case w/ me. :)



================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:879
+// A general interface to set branch_weights for SwitchInst/BranchInst/SelInst
+static void setBranchWeights(Instruction *I, ArrayRef<uint32_t> Weights) {
+  assert(isa<SwitchInst>(I) || isa<BranchInst>(I) || isa<SelectInst>(I));
----------------

I don't think you need this helper function.

The `setBranchWeights(SwitchInst*, ...)` version should be able to handle these cases if you change its signature to take an `Instruction*` and carry over the `assert` from this helper.

The reason is that the two-weight overload of `createBranchWeights()` exists only as a convenience for branch/select, to avoid boilerplate at call sites. Internally it just calls the version of `createBranchWeights()` that takes an `ArrayRef`.

I think it's better to follow the same approach here and avoid introducing a redundant function implementation.
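To illustrate the shape I have in mind, here is a self-contained toy model rather than real LLVM code: `Instruction`, `InstKind`, and the `ProfWeights` field are hypothetical stand-ins for `llvm::Instruction` and its `!prof` metadata. The point is only that one entry point taking any number of weights, plus a thin two-weight overload that forwards to it, covers switch, branch, and select without duplicating logic.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for llvm::Instruction and its subclasses: only the
// instruction kind and the attached branch weights are modeled.
enum class InstKind { Switch, Branch, Select, Other };

struct Instruction {
  InstKind Kind;
  std::vector<uint32_t> ProfWeights; // stands in for the !prof branch_weights
};

// Single generic entry point: accepts any number of weights, like the
// ArrayRef overload of createBranchWeights().
static void setBranchWeights(Instruction *I,
                             const std::vector<uint32_t> &Weights) {
  assert(I->Kind == InstKind::Switch || I->Kind == InstKind::Branch ||
         I->Kind == InstKind::Select);
  I->ProfWeights = Weights;
}

// Two-weight convenience overload: it only forwards, so nothing is duplicated.
static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
                             uint32_t FalseWeight) {
  setBranchWeights(I, {TrueWeight, FalseWeight});
}
```

In the real code the generic body would presumably keep doing what the existing helper does: build the node with `MDBuilder` and attach it with `setMetadata(LLVMContext::MD_prof, ...)`.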


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:1100
+/// Keep halving the weights until the sum of weights can fit in uint32_t.
+static void FitTotalWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Sum = std::accumulate(Weights.begin(), Weights.end(), (uint64_t)0);
----------------
I'm fairly certain we don't want to scale the weights down so that the sum of all weights fits in a `uint32_t`.

For one, lots of things assume the weights can be larger. Off the top of my head, Optimization Remarks are often filtered by hotness (i.e. branch weight). IIRC there are also transforms that only trigger when the weights exceed some threshold. Those may be better served by using a probability, but my recollection is that they don't use one. Someone w/ more comprehensive knowledge in this area should probably weigh in on this point. Still, I'm skeptical that we want to do this.

I'm a bit surprised that scaling weights down like this does not affect any tests. That may say more about untested paths than about this change in behavior being correct.
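A small standalone sketch of why this worries me. `fitTotalWeights` below is my approximation of the halving loop in the patch, and `isHot` with its threshold is a made-up stand-in for any consumer that compares raw weights against an absolute cutoff. The ratio between the weights survives, but their absolute magnitude does not.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Approximation of the behavior under review: keep halving every weight
// until the *sum* of all weights fits in uint32_t.
static void fitTotalWeights(std::vector<uint64_t> &Weights) {
  uint64_t Sum = std::accumulate(Weights.begin(), Weights.end(), uint64_t{0});
  while (Sum > UINT32_MAX) {
    Sum = 0;
    for (uint64_t &W : Weights) {
      W /= 2;
      Sum += W;
    }
  }
}

// Hypothetical consumer that only fires above an absolute weight threshold,
// e.g. a hotness filter. The threshold is made up for illustration.
static bool isHot(uint64_t Weight, uint64_t Threshold) {
  return Weight >= Threshold;
}
```

With weights `{3000000000, 3000000000}` one halving brings the sum under `UINT32_MAX`, leaving `{1500000000, 1500000000}`: still 50/50, but an edge that passed a cutoff of 2000000000 before no longer does.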


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:2998
 
+/// We want to make the new branch probability of PTI taken(from PredBB to BB)
+/// to be original one times the probability of BI taken(from BB to Dest),
----------------
What is PTI? Is this a common term from somewhere? If it's the predecessor, then `Pred` or `PredInst` seems to communicate more.


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3000-3001
+/// to be original one times the probability of BI taken(from BB to Dest),
+/// one simple way is to scale all the PTI not_taken branch weight by a factor
+/// X. By calculation, X should be the formula below:
+/// X = 1 + (BNTW/BTW) * (totW/totNTW)
----------------
Can you rephrase this? I'd like this to be clear to all readers, and it's currently a bit hard to parse.
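For what it's worth, this is how I reconstructed the formula while reading. It's my own derivation from the quoted comment, so please correct me if it doesn't match your intent. Let T be the total weight of the predecessor's edges into BB, N the total weight of its other edges, W = T + N, and b_t and b_n the weights of BB's edge to Dest and of BB's other edge (all assumed nonzero). Scaling every not-taken weight by X should make the new probability of the edge into BB equal the old one times P(BB->Dest):

```latex
\frac{T}{T + XN} = \frac{T}{W}\cdot\frac{b_t}{b_t + b_n}
\;\Longrightarrow\; T + XN = W\,\frac{b_t + b_n}{b_t} = W + W\,\frac{b_n}{b_t}
\;\Longrightarrow\; (X - 1)\,N = W\,\frac{b_n}{b_t} \quad (\text{using } T = W - N)
\;\Longrightarrow\; X = 1 + \frac{b_n}{b_t}\cdot\frac{W}{N}

w' = X\,w = w + \frac{W\,w}{N}\cdot\frac{b_n}{b_t}
```

The last line is the per-edge update, i.e. what the code comment at line 3022 writes as `Weights[I] + (totW*Weights[I]/totNTW) * (BNTW/BTW)`.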


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3003-3005
+/// in which "BNTW" is BINotTakenWeight and "BTW" is BITakenWeight, "totW" is
+/// the sum of branch weights of PTI, "totNTW" is the sum of all PTI not_taken
+/// branch weight(which is totW minus sum of PTI taken branches)
----------------
I'd suggest avoiding variable names from the implementation here where you're documenting your algorithm and approach.

Can you also break each variable def/description onto its own line to make reading a bit easier?


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3007
+/// The value will not overflow if totW is in uint32_t range(see details below)
+static void ScaleWeights(uint64_t BITakenWeight, uint64_t BINotTakenWeight,
+                         SmallVector<uint64_t, 8> &PTIWeights,
----------------
This function seems to go to a lot of trouble to avoid overflowing `uint32_t`, which is good, but things would be significantly simpler if you did the calculations in `uint64_t` and then scaled the weights down the same way we do elsewhere. If that is common functionality, then it can be moved somewhere shared, like `IR/ProfDataUtils`.

Are there other reasons, concerns, or context here that I'm missing?
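By "scaled the weights the same way we do elsewhere" I mean something like the existing `FitWeights` helper in SimplifyCFG.cpp, which, as I recall it, shifts every weight right by the same amount until the largest one fits in `uint32_t`. Below is a standalone sketch of that idea, not a copy of the actual helper. Note that it only requires the largest weight to fit, not the sum.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// FitWeights-style rescale: do the arithmetic in uint64_t, then shift every
// weight right by the smallest amount that makes the largest weight fit in
// uint32_t. Relative ratios are preserved up to truncation.
static void fitWeights(std::vector<uint64_t> &Weights) {
  if (Weights.empty())
    return;
  uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
  unsigned Shift = 0;
  while ((Max >> Shift) > UINT32_MAX)
    ++Shift;
  for (uint64_t &W : Weights)
    W >>= Shift;
}
```

For example, `{1 << 33, 1 << 32}` becomes `{1 << 31, 1 << 30}` (still 2:1), while weights that already fit are left untouched.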


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3012
+      std::accumulate(PTIWeights.begin(), PTIWeights.end(), (uint64_t)0);
+  assert(PTITotalWeight <= UINT32_MAX);
+  uint64_t PTITotalTakenWeight = 0;
----------------
I don't think it's correct to assert this, is it? Each weight can use the full 32 bits, so their sum can overflow `uint32_t`, and that is an expected possibility (see `ProfDataUtils.cpp`). So why is this case different?
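Concretely, two edges that each carry the maximum 32-bit weight already sum past `UINT32_MAX`. The plain C++ sketch below (no LLVM; `sumWeights` and `sumWeightsNarrow` are illustrative helpers) shows the sum computed in `uint64_t` next to what a `uint32_t` accumulator would silently wrap to.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Sum 32-bit branch weights in a 64-bit accumulator so the total cannot wrap.
static uint64_t sumWeights(const std::vector<uint32_t> &Weights) {
  return std::accumulate(Weights.begin(), Weights.end(), uint64_t{0});
}

// The same sum with a 32-bit accumulator, which wraps modulo 2^32.
static uint32_t sumWeightsNarrow(const std::vector<uint32_t> &Weights) {
  return std::accumulate(Weights.begin(), Weights.end(), uint32_t{0});
}
```

With `{UINT32_MAX, UINT32_MAX}` the 64-bit sum is 8589934590, while the 32-bit accumulator yields `UINT32_MAX - 1`.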


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3017
+
+  for (unsigned I = 0, E = PTIWeights.size(); I != E; I++) {
+    if (PTITakenIndexes.find(I) != PTITakenIndexes.end())
----------------
Can you rename this `Idx`, or something that more clearly indicates it's an index? It's extremely common in our codebase to use `I` for instructions, and in this case I'm finding myself misreading the code.

This is compounded below, where you say:
>     // Branch "I" is also one of the PTI not_taken branch and that means




================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3022
+    // Small adjustment by doing add below,
+    // so newWeights[I] = Weights[I] + (totW*Weights[I]/totNTW) * (BNTW/BTW)
+
----------------
This is significantly easier to follow than the code below. Why not use these names?


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3027
+    // Add max(1,w) because nothing prevents weight to be 0
+    uint64_t Inc = PTITotalWeight * PTIWeights[I] /
+                   std::max(1UL, (PTITotalWeight - PTITotalTakenWeight));
----------------
If you need an intermediate variable, please call out what it is in your equations. 

We can all figure it out, but it helps if it's either obvious from the naming or noted in the comment that you're calculating `(totW*Weights[I]/totNTW)`.
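Here's roughly what I mean, as a standalone sketch with hypothetical names (`scaleNotTakenWeights`, `ShareOfTotal`, `Increment`, etc. are mine, not from the patch). The loop uses `Idx`, and each intermediate corresponds to a named term of the update in the comment above. It keeps the patch's current assumption that the total fits in 32 bits (which I questioned above), since with 32-bit inputs that keeps every intermediate product inside `uint64_t`; the results can still exceed 32 bits and would need rescaling afterwards.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// For every edge that does NOT lead into BB:
//   NewWeight = Weight + ShareOfTotal * (BBOtherWeight / BBToDestWeight)
//   where ShareOfTotal = TotalWeight * Weight / NotTakenWeight
static std::vector<uint64_t>
scaleNotTakenWeights(const std::vector<uint32_t> &Weights,
                     const std::set<unsigned> &TakenIdxs,
                     uint32_t BBToDestWeight, uint32_t BBOtherWeight) {
  uint64_t TotalWeight = 0, TakenWeight = 0;
  for (unsigned Idx = 0, E = Weights.size(); Idx != E; ++Idx) {
    TotalWeight += Weights[Idx];
    if (TakenIdxs.count(Idx))
      TakenWeight += Weights[Idx];
  }
  assert(TotalWeight <= UINT32_MAX && "sketch assumes a 32-bit total");
  // Avoid dividing by zero, similar to the patch's std::max(1, ...) guard.
  uint64_t NotTakenWeight = std::max<uint64_t>(1, TotalWeight - TakenWeight);
  uint64_t DestWeight = std::max<uint64_t>(1, BBToDestWeight);

  std::vector<uint64_t> NewWeights(Weights.begin(), Weights.end());
  for (unsigned Idx = 0, E = Weights.size(); Idx != E; ++Idx) {
    if (TakenIdxs.count(Idx))
      continue; // Edges into BB keep their original weight.
    uint64_t ShareOfTotal = TotalWeight * Weights[Idx] / NotTakenWeight;
    uint64_t Increment = ShareOfTotal * BBOtherWeight / DestWeight;
    NewWeights[Idx] = Weights[Idx] + Increment;
  }
  return NewWeights;
}
```

With weights `{2, 4, 4}`, edge 0 leading into BB, and BB's weights 1:1, this yields `{2, 9, 9}`: the edge into BB drops from 2/10 to 2/20, i.e. 20% times 50%.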


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3045
+    BasicBlock *PredBB = PredBBs[i];
+    Instruction *PTI = PredBB->getTerminator();
+    // We only need to fix branch weights for BranchInst/SwitchInst
----------------
OK, so PTI is the Predecessor's Terminator Inst... I still think we need to name some of this differently. PTI is probably fine for referring to the terminator itself when you care about it being a terminator. But in most cases here you're really talking about incoming branches and their weights. PTIWeights doesn't communicate that IMO.




================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3048-3062
+    if (BranchInst *PBI = dyn_cast<BranchInst>(PTI)) {
+      if (PBI->isUnconditional())
+        continue;
+      uint64_t PBITrue, PBIFalse;
+      bool PBIHasWeights = PBI->extractProfMetadata(PBITrue, PBIFalse);
+      if (!PBIHasWeights)
+        PBITrue = PBIFalse = 1;
----------------
Can't you simplify this significantly by treating branch and switch uniformly? We have APIs in ProfDataUtils just for that...

You should be able to avoid most of the `dyn_cast`s and the separate conditions almost entirely, right?

https://github.com/llvm/llvm-project/blob/a7441289e2eb8ccf71c1a60b71d898069e82b22b/llvm/lib/IR/ProfDataUtils.cpp#L111

Assigning default weights can also be made uniform.
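To sketch what "uniform" could look like, here is a standalone toy model: the `Terminator` struct and `getWeightsOrDefault` are hypothetical, not LLVM types. In the real code I'd expect the weights to come from `extractBranchWeights()` in ProfDataUtils (linked above) and the count from `getNumSuccessors()`, but please check the exact signatures there.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of a terminator: its successor count plus optional
// branch_weights. Conditional branches and switches look the same here.
struct Terminator {
  unsigned NumSuccessors;
  std::optional<std::vector<uint32_t>> BranchWeights;
};

// One code path for every terminator kind: use the attached weights when they
// are present and match the successor count; otherwise give every successor
// a default weight of 1. Callers can still skip single-successor terminators.
static std::vector<uint64_t> getWeightsOrDefault(const Terminator &T) {
  if (T.BranchWeights && T.BranchWeights->size() == T.NumSuccessors)
    return std::vector<uint64_t>(T.BranchWeights->begin(),
                                 T.BranchWeights->end());
  return std::vector<uint64_t>(T.NumSuccessors, 1);
}
```

A conditional branch with weights `{90, 10}` and a four-way switch without metadata then go through the same function, yielding `{90, 10}` and `{1, 1, 1, 1}`.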


================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3052
+      uint64_t PBITrue, PBIFalse;
+      bool PBIHasWeights = PBI->extractProfMetadata(PBITrue, PBIFalse);
+      if (!PBIHasWeights)
----------------
We have APIs for checking if an instruction has branch weight metadata.

see https://github.com/llvm/llvm-project/blob/a7441289e2eb8ccf71c1a60b71d898069e82b22b/llvm/lib/IR/ProfDataUtils.cpp#L99



================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3161-3164
+    //   BlockFreq(RealDest)
+    // = BlockFreq(BB) * P(BB->RealDest) + eps1
+    // = (BlockFreq(PredBB) * P(PredBB->BB) + eps2) * P(BB->RealDest) + eps1
+    // = BlockFreq(PredBB) * P(PredBB->BB) * P(BB->RealDest) + eps
----------------



================
Comment at: llvm/lib/Transforms/Utils/SimplifyCFG.cpp:3171-3172
+    if (BIHasWeights) {
+      if (!CB->getZExtValue())
+        std::swap(BITakenWeight, BINotTakenWeight);
+      FixupBranchWeight(EdgeBB, PredBBs, BITakenWeight, BINotTakenWeight);
----------------
Are we inverting the branch condition? Why are we swapping? I think I've missed something here...


================
Comment at: llvm/test/Transforms/SimplifyCFG/fold-cond-bi-fix-weight.ll:12-13
+;; Actual probability of hitting foo is 1/10(10%)
+;; Before fix, the probability of hitting foo is 11/20(55%)
+;; After fix, the probability of hitting foo is 29/200(14.5%)
+define void @branchinst_noweight(i8 %a0, i8 %a1, i8 %a2) {
----------------
I don't know if these comments (here and elsewhere) make sense outside of this patch, since "Before" and "After" are only meaningful in the context of your change...

I think if you want to explain the purpose of the test more thoroughly, you'll need to provide context. A summary of the incorrect behavior and what a regression looks like is probably enough. You can also reference this patch, so anyone coming along later can easily understand why this test exists and what it's checking.

WDYT?


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D131287/new/

https://reviews.llvm.org/D131287


