[llvm-commits] [llvm] r162868 - /llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp

Andrew Trick atrick at apple.com
Wed Aug 29 14:46:38 PDT 2012


Author: atrick
Date: Wed Aug 29 16:46:38 2012
New Revision: 162868

URL: http://llvm.org/viewvc/llvm-project?rev=162868&view=rev
Log:
Preserve branch profile metadata during switch formation.

Patch by Michael Ilseman!
This fixes SimplifyCFGOpt::FoldValueComparisonIntoPredecessors to preserve branch weight metadata when folding conditional branches into switches.

void foo(int x) {
  if (x == 0)
    bar(1);
  else if (__builtin_expect(x == 10, 1))
    bar(2);
  else if (x == 20)
    bar(3);
}

CFG:

B0
|  \
|   X0
B10
|  \
|   X10
B20
|  \
E   X20

Merge B0-B10 (sum-weights(B) denotes the sum of the weights on B's outgoing edges):
w(B0-X0) = w(B0-X0) * sum-weights(B10) = w(B0-X0) * (w(B10-X10) + w(B10-B20))
w(B0-X10) = w(B0-B10) * w(B10-X10)
w(B0-B20) = w(B0-B10) * w(B10-B20)

B0 __
| \  \
| X10 X0
B20
|  \
E  X20

Merge B0-B20:
w(B0-X0) = w(B0-X0) * sum-weights(B20) = w(B0-X0) * (w(B20-E) + w(B20-X20))
w(B0-X10) = w(B0-X10) * sum-weights(B20) = w(B0-X10) * (w(B20-E) + w(B20-X20))
w(B0-X20) = w(B0-B20) * w(B20-X20)
w(B0-E) = w(B0-B20) * w(B20-E)

Modified:
    llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp?rev=162868&r1=162867&r2=162868&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp Wed Aug 29 16:46:38 2012
@@ -615,6 +615,9 @@
   assert(ThisVal && "This isn't a value comparison!!");
   if (ThisVal != PredVal) return false;  // Different predicates.
 
+  // TODO: Preserve branch weight metadata, similarly to how
+  // FoldValueComparisonIntoPredecessors preserves it.
+
   // Find out information about when control will move from Pred to TI's block.
   std::vector<ValueEqualityComparisonCase> PredCases;
   BasicBlock *PredDef = GetValueEqualityComparisonCases(Pred->getTerminator(),
@@ -738,6 +741,67 @@
   return -1;
 }
 
+/// Returns true if I carries !prof metadata whose first operand is the
+/// "branch_weights" tag string.
+static inline bool HasBranchWeights(const Instruction* I) {
+  MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
+  // Check the operand count before touching operand 0: reading it from an
+  // empty node would index past the node's operands.
+  if (ProfMD && ProfMD->getNumOperands() > 0 && ProfMD->getOperand(0))
+    if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
+      return MDS->getString().equals("branch_weights");
+
+  return false;
+}
+
+/// Tries to get a branch weight for the given instruction, returns NULL if it
+/// can't. Pos starts at 0 and indexes the weight operands only, i.e. this
+/// reads operand 1 + Pos of the !prof node (operand 0 is the
+/// "branch_weights" tag). A Pos outside the available weights yields NULL.
+static ConstantInt* GetWeight(Instruction* I, int Pos) {
+  if (Pos < 0 || !HasBranchWeights(I))
+    return 0;
+
+  MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
+  assert(ProfMD->getNumOperands() >= 3);
+
+  // Bounds-check instead of reading past the last weight operand, so the
+  // documented "returns NULL if it can't" contract actually holds.
+  unsigned OpIdx = 1 + unsigned(Pos);
+  if (OpIdx >= ProfMD->getNumOperands())
+    return 0;
+
+  return dyn_cast<ConstantInt>(ProfMD->getOperand(OpIdx));
+}
+
+/// Scale the given weights based on the new TI's metadata. Scaling is done by
+/// multiplying every weight except Weights[0] by the sum of STI's branch
+/// weights. Weights[0] (the default) is left untouched because the caller
+/// replaces it during folding.
+static void ScaleWeights(Instruction* STI, MutableArrayRef<uint64_t> Weights) {
+  assert(HasBranchWeights(STI));
+  MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof);
+
+  // Sum the successor's weights. Each weight operand is 32 bits wide, so the
+  // sum of several of them can exceed 32 bits: accumulate in 64 bits.
+  uint64_t Scale = 0;
+  for (unsigned i = 1, e = ProfMD->getNumOperands(); i < e; ++i) {
+    ConstantInt* CI = cast<ConstantInt>(ProfMD->getOperand(i));
+    Scale += CI->getValue().getZExtValue();
+  }
+
+  // Skip default, as it's replaced during the folding. Saturate instead of
+  // wrapping on overflow so a huge weight never turns into a tiny one.
+  const uint64_t MaxWeight = ~uint64_t(0);
+  for (unsigned i = 1, e = Weights.size(); i < e; ++i) {
+    if (Scale != 0 && Weights[i] > MaxWeight / Scale)
+      Weights[i] = MaxWeight;
+    else
+      Weights[i] *= Scale;
+  }
+}
+
+/// Keep halving all the weights until every one of them fits in a uint32_t.
+/// All weights are shifted by the same amount, so their ratios are
+/// approximately preserved.
+static void FitWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Max = 0;
+  for (unsigned i = 0, e = Weights.size(); i < e; ++i)
+    if (Weights[i] > Max)
+      Max = Weights[i];
+
+  // A single halving is not enough: a scaled weight can be roughly the
+  // product of two 32-bit values, and anything still above UINT_MAX would be
+  // truncated when converted to uint32_t. Count how many halvings the
+  // largest weight needs (at most 32) and apply them all at once.
+  unsigned Shift = 0;
+  while ((Max >> Shift) > UINT_MAX)
+    ++Shift;
+
+  if (Shift == 0)
+    return;
+
+  for (unsigned i = 0, e = Weights.size(); i < e; ++i)
+    Weights[i] >>= Shift;
+}
+
 /// FoldValueComparisonIntoPredecessors - The specified terminator is a value
 /// equality comparison instruction (either a switch or a branch on "X == c").
 /// See if any of the predecessors of the terminator block are value comparisons
@@ -770,6 +834,55 @@
       // build.
       SmallVector<BasicBlock*, 8> NewSuccessors;
 
+      // Update the branch weight metadata along the way
+      SmallVector<uint64_t, 8> Weights;
+      uint64_t PredDefaultWeight = 0;
+      bool PredHasWeights = HasBranchWeights(PTI);
+      bool SuccHasWeights = HasBranchWeights(TI);
+
+      if (PredHasWeights) {
+        MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof);
+        assert(MD);
+        for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
+          ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
+          assert(CI);
+          Weights.push_back(CI->getValue().getZExtValue());
+        }
+
+        // If the predecessor is a conditional eq, then swap the default weight
+        // to be the first entry.
+        if (BranchInst* BI = dyn_cast<BranchInst>(PTI)) {
+          assert(Weights.size() == 2);
+          ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
+
+          if (ICI->getPredicate() == ICmpInst::ICMP_EQ) {
+            std::swap(Weights.front(), Weights.back());
+          }
+        }
+
+        PredDefaultWeight = Weights.front();
+      } else if (SuccHasWeights) {
+        // If there are no predecessor weights but there are successor weights,
+        // populate Weights with 1, which will later be scaled to the sum of
+        // successor's weights
+        Weights.assign(1 + PredCases.size(), 1);
+        PredDefaultWeight = 1;
+      }
+
+      uint64_t SuccDefaultWeight = 0;
+      if (SuccHasWeights) {
+        int Index = 0;
+        if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
+          ICmpInst* ICI = dyn_cast<ICmpInst>(BI->getCondition());
+          assert(ICI);
+
+          if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+            Index = 1;
+        }
+
+        SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue();
+      }
+
       if (PredDefault == BB) {
         // If this is the default destination from PTI, only the edges in TI
         // that don't occur in PTI, or that branch to BB will be activated.
@@ -780,6 +893,12 @@
           else {
             // The default destination is BB, we don't need explicit targets.
             std::swap(PredCases[i], PredCases.back());
+
+            if (PredHasWeights) {
+              std::swap(Weights[i+1], Weights.back());
+              Weights.pop_back();
+            }
+
             PredCases.pop_back();
             --i; --e;
           }
@@ -790,14 +909,35 @@
           PredDefault = BBDefault;
           NewSuccessors.push_back(BBDefault);
         }
+
+        if (SuccHasWeights) {
+          ScaleWeights(TI, Weights);
+          Weights.front() *= SuccDefaultWeight;
+        } else if (PredHasWeights) {
+          Weights.front() /= (1 + BBCases.size());
+        }
+
         for (unsigned i = 0, e = BBCases.size(); i != e; ++i)
           if (!PTIHandled.count(BBCases[i].Value) &&
               BBCases[i].Dest != BBDefault) {
             PredCases.push_back(BBCases[i]);
             NewSuccessors.push_back(BBCases[i].Dest);
+            if (SuccHasWeights) {
+              Weights.push_back(PredDefaultWeight *
+                                GetWeight(TI, i)->getValue().getZExtValue());
+            } else if (PredHasWeights) {
+              // Split the old default's weight amongst the children
+              assert(PredDefaultWeight != 0);
+              Weights.push_back(PredDefaultWeight / (1 + BBCases.size()));
+            }
           }
 
       } else {
+        // FIXME: preserve branch weight metadata, similarly to the 'then'
+        // above. For now, drop it.
+        PredHasWeights = false;
+        SuccHasWeights = false;
+
         // If this is not the default destination from PSI, only the edges
         // in SI that occur in PSI with a destination of BB will be
         // activated.
@@ -851,6 +991,17 @@
       for (unsigned i = 0, e = PredCases.size(); i != e; ++i)
         NewSI->addCase(PredCases[i].Value, PredCases[i].Dest);
 
+      if (PredHasWeights || SuccHasWeights) {
+        // Halve the weights if any of them cannot fit in an uint32_t
+        FitWeights(Weights);
+
+        SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
+
+        NewSI->setMetadata(LLVMContext::MD_prof,
+                           MDBuilder(BB->getContext()).
+                           createBranchWeights(MDWeights));
+      }
+
       EraseTerminatorInstAndDCECond(PTI);
 
       // Okay, last check.  If BB is still a successor of PSI, then we must
@@ -2349,6 +2500,9 @@
   // transformation.  A switch with one value is just an cond branch.
   if (ExtraCase && Values.size() < 2) return false;
 
+  // TODO: Preserve branch weight metadata, similarly to how
+  // FoldValueComparisonIntoPredecessors preserves it.
+
   // Figure out which block is which destination.
   BasicBlock *DefaultBB = BI->getSuccessor(1);
   BasicBlock *EdgeBB    = BI->getSuccessor(0);





More information about the llvm-commits mailing list