[llvm] r215135 - [Branch probability] Recompute branch weights of tail-merged basic blocks.

Akira Hatanaka ahatanaka at apple.com
Thu Aug 7 12:30:13 PDT 2014


Author: ahatanak
Date: Thu Aug  7 14:30:13 2014
New Revision: 215135

URL: http://llvm.org/viewvc/llvm-project?rev=215135&view=rev
Log:
[Branch probability] Recompute branch weights of tail-merged basic blocks.

BranchFolderPass was not correctly setting the branch weights of basic blocks
that tail-merging created or merged. This patch recomputes the weights of
tail-merged blocks using the following formula:

branch_weight(merged block to successor j) =
sum(block_frequency(bb) * branch_probability(bb -> j))

where the sum ranges over every block bb in the set of merged blocks. The
merged block's frequency is likewise set to the sum of block_frequency(bb), and
the resulting weights are scaled down so that they fit in 32 bits.

<rdar://problem/16256423>

Added:
    llvm/trunk/test/CodeGen/ARM/tail-merge-branch-weight.ll
Modified:
    llvm/trunk/lib/CodeGen/BranchFolding.cpp
    llvm/trunk/lib/CodeGen/BranchFolding.h
    llvm/trunk/lib/CodeGen/IfConversion.cpp

Modified: llvm/trunk/lib/CodeGen/BranchFolding.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/BranchFolding.cpp?rev=215135&r1=215134&r2=215135&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/BranchFolding.cpp (original)
+++ llvm/trunk/lib/CodeGen/BranchFolding.cpp Thu Aug  7 14:30:13 2014
@@ -20,6 +20,8 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
+#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
@@ -71,6 +73,8 @@ namespace {
     bool runOnMachineFunction(MachineFunction &MF) override;
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequired<MachineBlockFrequencyInfo>(); // Merged-tail weights.
+      AU.addRequired<MachineBranchProbabilityInfo>(); // Merged-tail weights.
       AU.addRequired<TargetPassConfig>();
       MachineFunctionPass::getAnalysisUsage(AU);
     }
@@ -92,21 +96,24 @@ bool BranchFolderPass::runOnMachineFunct
   // HW that requires structurized CFG.
   bool EnableTailMerge = !MF.getTarget().requiresStructuredCFG() &&
       PassConfig->getEnableTailMerge();
-  BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true);
+  BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true,
+                      getAnalysis<MachineBlockFrequencyInfo>(),
+                      getAnalysis<MachineBranchProbabilityInfo>());
   return Folder.OptimizeFunction(MF, MF.getSubtarget().getInstrInfo(),
                                  MF.getSubtarget().getRegisterInfo(),
                                  getAnalysisIfAvailable<MachineModuleInfo>());
 }
 
-
-BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist) {
+BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist,
+                           const MachineBlockFrequencyInfo &FreqInfo,
+                           const MachineBranchProbabilityInfo &ProbInfo)
+    : EnableHoistCommonCode(CommonHoist), MBBFreqInfo(FreqInfo),
+      MBPI(ProbInfo) {
   switch (FlagEnableTailMerge) {
   case cl::BOU_UNSET: EnableTailMerge = defaultEnableTailMerge; break;
   case cl::BOU_TRUE: EnableTailMerge = true; break;
   case cl::BOU_FALSE: EnableTailMerge = false; break;
   }
-
-  EnableHoistCommonCode = CommonHoist;
 }
 
 /// RemoveDeadBlock - Remove the specified dead machine basic block from the
@@ -433,6 +440,9 @@ MachineBasicBlock *BranchFolder::SplitMB
   // Splice the code over.
   NewMBB->splice(NewMBB->end(), &CurMBB, BBI1, CurMBB.end());
 
+  // NewMBB inherits CurMBB's block frequency (recorded in MBBFreqInfo).
+  MBBFreqInfo.setBlockFreq(NewMBB, MBBFreqInfo.getBlockFreq(&CurMBB));
+
   // For targets that use the register scavenger, we must maintain LiveIns.
   MaintainLiveIns(&CurMBB, NewMBB);
 
@@ -502,6 +512,21 @@ BranchFolder::MergePotentialsElt::operat
 #endif
 }
 
+BlockFrequency
+BranchFolder::MBFIWrapper::getBlockFreq(const MachineBasicBlock *MBB) const {
+  auto I = MergedBBFreq.find(MBB);
+  // A frequency recorded via setBlockFreq takes precedence over MBFI.
+  if (I != MergedBBFreq.end())
+    return I->second;
+  // Otherwise defer to the wrapped MachineBlockFrequencyInfo.
+  return MBFI.getBlockFreq(MBB);
+}
+
+void BranchFolder::MBFIWrapper::setBlockFreq(const MachineBasicBlock *MBB,
+                                             BlockFrequency F) {
+  MergedBBFreq[MBB] = F; // The wrapped MBFI itself is never modified.
+}
+
 /// CountTerminators - Count the number of terminators in the given
 /// block and set I to the position of the first non-terminator, if there
 /// is one, or MBB->end() otherwise.
@@ -804,6 +829,10 @@ bool BranchFolder::TryTailMergeBlocks(Ma
     }
 
     MachineBasicBlock *MBB = SameTails[commonTailIndex].getBlock();
+    // Recompute the common tail MBB's edge weights and block frequency from
+    // the blocks in SameTails, before the other blocks are redirected to it.
+    setCommonTailEdgeWeights(*MBB);
+
     // MBB is common tail.  Adjust all other BB's to jump to this one.
     // Traversal must be forwards so erases work.
     DEBUG(dbgs() << "\nUsing common tail in BB#" << MBB->getNumber()
@@ -966,6 +995,44 @@ bool BranchFolder::TailMergeBlocks(Machi
   return MadeChange;
 }
 
+void BranchFolder::setCommonTailEdgeWeights(MachineBasicBlock &TailMBB) {
+  SmallVector<BlockFrequency, 2> EdgeFreqLs(TailMBB.succ_size());
+  BlockFrequency AccumulatedMBBFreq;
+  // TailMBB's new frequency is the sum of freq(bb) over SameTails, and
+  // EdgeFreqLs[j] accumulates the frequency of TailMBB's j-th successor edge:
+  //  edgeFreq(j) = sum (freq(bb) * edgeProb(bb, j)),
+  //  where bb is a basic block that is in SameTails.
+  for (const auto &Src : SameTails) {
+    const MachineBasicBlock *SrcMBB = Src.getBlock();
+    BlockFrequency BlockFreq = MBBFreqInfo.getBlockFreq(SrcMBB);
+    AccumulatedMBBFreq += BlockFreq;
+
+    // Edge weights need not be recomputed if TailMBB has fewer than two
+    // successors.
+    if (TailMBB.succ_size() <= 1)
+      continue;
+
+    auto EdgeFreq = EdgeFreqLs.begin();
+    // NOTE(review): assumes TailMBB's successors are also SrcMBB's; confirm.
+    for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end();
+         SuccI != SuccE; ++SuccI, ++EdgeFreq)
+      *EdgeFreq += BlockFreq * MBPI.getEdgeProbability(SrcMBB, *SuccI);
+  }
+
+  MBBFreqInfo.setBlockFreq(&TailMBB, AccumulatedMBBFreq);
+
+  if (TailMBB.succ_size() <= 1)
+    return;
+  // Scale the edge frequencies down so that each weight is below UINT32_MAX.
+  auto MaxEdgeFreq = *std::max_element(EdgeFreqLs.begin(), EdgeFreqLs.end());
+  uint64_t Scale = MaxEdgeFreq.getFrequency() / UINT32_MAX + 1;
+  auto EdgeFreq = EdgeFreqLs.begin();
+
+  for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end();
+       SuccI != SuccE; ++SuccI, ++EdgeFreq)
+    TailMBB.setSuccWeight(SuccI, EdgeFreq->getFrequency() / Scale);
+}
+
 //===----------------------------------------------------------------------===//
 //  Branch Optimization
 //===----------------------------------------------------------------------===//

Modified: llvm/trunk/lib/CodeGen/BranchFolding.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/BranchFolding.h?rev=215135&r1=215134&r2=215135&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/BranchFolding.h (original)
+++ llvm/trunk/lib/CodeGen/BranchFolding.h Thu Aug  7 14:30:13 2014
@@ -12,9 +12,12 @@
 
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/Support/BlockFrequency.h"
 #include <vector>
 
 namespace llvm {
+  class MachineBlockFrequencyInfo;
+  class MachineBranchProbabilityInfo;
   class MachineFunction;
   class MachineModuleInfo;
   class RegScavenger;
@@ -23,7 +26,9 @@ namespace llvm {
 
   class BranchFolder {
   public:
-    explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist);
+    explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist,
+                          const MachineBlockFrequencyInfo &MBFI,
+                          const MachineBranchProbabilityInfo &MBPI);
 
     bool OptimizeFunction(MachineFunction &MF,
                           const TargetInstrInfo *tii,
@@ -92,9 +97,26 @@ namespace llvm {
     MachineModuleInfo *MMI;
     RegScavenger *RS;
 
+    /// \brief This class keeps track of block frequencies of newly created
+    /// and tail-merged blocks, deferring to MBFI for all other blocks.
+    class MBFIWrapper {
+    public:
+      MBFIWrapper(const MachineBlockFrequencyInfo &I) : MBFI(I) {}
+      BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
+      void setBlockFreq(const MachineBasicBlock *MBB, BlockFrequency F);
+
+    private:
+      const MachineBlockFrequencyInfo &MBFI;
+      DenseMap<const MachineBasicBlock *, BlockFrequency> MergedBBFreq;
+    };
+    // Frequency/probability info used to reweight tail-merged blocks.
+    MBFIWrapper MBBFreqInfo;
+    const MachineBranchProbabilityInfo &MBPI;
+
     bool TailMergeBlocks(MachineFunction &MF);
     bool TryTailMergeBlocks(MachineBasicBlock* SuccBB,
                        MachineBasicBlock* PredBB);
+    void setCommonTailEdgeWeights(MachineBasicBlock &TailMBB);
     void MaintainLiveIns(MachineBasicBlock *CurMBB,
                          MachineBasicBlock *NewMBB);
     void ReplaceTailWithBranchTo(MachineBasicBlock::iterator OldInst,

Modified: llvm/trunk/lib/CodeGen/IfConversion.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/IfConversion.cpp?rev=215135&r1=215134&r2=215135&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/IfConversion.cpp (original)
+++ llvm/trunk/lib/CodeGen/IfConversion.cpp Thu Aug  7 14:30:13 2014
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -161,6 +162,7 @@ namespace {
     const TargetLoweringBase *TLI;
     const TargetInstrInfo *TII;
     const TargetRegisterInfo *TRI;
+    const MachineBlockFrequencyInfo *MBFI; // Handed to BranchFolder.
     const MachineBranchProbabilityInfo *MBPI;
     MachineRegisterInfo *MRI;
 
@@ -177,6 +179,7 @@ namespace {
     }
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequired<MachineBlockFrequencyInfo>(); // For BranchFolder.
       AU.addRequired<MachineBranchProbabilityInfo>();
       MachineFunctionPass::getAnalysisUsage(AU);
     }
@@ -272,6 +275,7 @@ bool IfConverter::runOnMachineFunction(M
   TLI = MF.getSubtarget().getTargetLowering();
   TII = MF.getSubtarget().getInstrInfo();
   TRI = MF.getSubtarget().getRegisterInfo();
+  MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
   MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
   MRI = &MF.getRegInfo();
 
@@ -286,7 +290,7 @@ bool IfConverter::runOnMachineFunction(M
   bool BFChange = false;
   if (!PreRegAlloc) {
     // Tail merge tend to expose more if-conversion opportunities.
-    BranchFolder BF(true, false);
+    BranchFolder BF(true, false, *MBFI, *MBPI);
     BFChange = BF.OptimizeFunction(MF, TII, MF.getSubtarget().getRegisterInfo(),
                                    getAnalysisIfAvailable<MachineModuleInfo>());
   }
@@ -419,7 +423,7 @@ bool IfConverter::runOnMachineFunction(M
   BBAnalysis.clear();
 
   if (MadeChange && IfCvtBranchFold) {
-    BranchFolder BF(false, false);
+    BranchFolder BF(false, false, *MBFI, *MBPI);
     BF.OptimizeFunction(MF, TII, MF.getSubtarget().getRegisterInfo(),
                         getAnalysisIfAvailable<MachineModuleInfo>());
   }

Added: llvm/trunk/test/CodeGen/ARM/tail-merge-branch-weight.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/ARM/tail-merge-branch-weight.ll?rev=215135&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/ARM/tail-merge-branch-weight.ll (added)
+++ llvm/trunk/test/CodeGen/ARM/tail-merge-branch-weight.ll Thu Aug  7 14:30:13 2014
@@ -0,0 +1,44 @@
+; RUN: llc -mtriple=arm-apple-ios -print-machineinstrs=branch-folder \
+; RUN: %s -o /dev/null 2>&1 | FileCheck %s
+
+; Branch probability of the tail-merged block:
+;
+; p(L0_L1 -> L2) = p(entry -> L0) * p(L0 -> L2) + p(entry -> L1) * p(L1 -> L2)
+;                = 0.2 * 0.6 + 0.8 * 0.3 = 0.36
+; p(L0_L1 -> L3) = p(entry -> L0) * p(L0 -> L3) + p(entry -> L1) * p(L1 -> L3)
+;                = 0.2 * 0.4 + 0.8 * 0.7 = 0.64
+; The CHECKed successor weights (13, 24) approximate this 0.36 : 0.64 ratio.
+; CHECK: # Machine code for function test0:
+; CHECK: Successors according to CFG: BB#{{[0-9]+}}(13) BB#{{[0-9]+}}(24)
+; CHECK: BB#{{[0-9]+}}:
+; CHECK: BB#{{[0-9]+}}:
+; CHECK: # End machine code for function test0.
+
+define i32 @test0(i32 %n, i32 %m, i32* nocapture %a, i32* nocapture %b) {
+entry:
+  %cmp = icmp sgt i32 %n, 0
+  br i1 %cmp, label %L0, label %L1, !prof !0
+
+L0:                                          ; preds = %entry
+  store i32 12, i32* %a, align 4
+  store i32 18, i32* %b, align 4
+  %cmp1 = icmp eq i32 %m, 8
+  br i1 %cmp1, label %L2, label %L3, !prof !1
+
+L1:                                          ; preds = %entry
+  store i32 14, i32* %a, align 4
+  store i32 18, i32* %b, align 4
+  %cmp3 = icmp eq i32 %m, 8
+  br i1 %cmp3, label %L2, label %L3, !prof !2
+
+L2:                                               ; preds = %L1, %L0
+  br label %L3
+
+L3:                                           ; preds = %L0, %L1, %L2
+  %retval.0 = phi i32 [ 100, %L2 ], [ 6, %L1 ], [ 6, %L0 ]
+  ret i32 %retval.0
+}
+
+!0 = metadata !{metadata !"branch_weights", i32 200, i32 800}
+!1 = metadata !{metadata !"branch_weights", i32 600, i32 400}
+!2 = metadata !{metadata !"branch_weights", i32 300, i32 700}





More information about the llvm-commits mailing list