[llvm-commits] [llvm] r163635 - in /llvm/trunk: lib/Transforms/Utils/SimplifyCFG.cpp test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll

Manman Ren mren at apple.com
Tue Sep 11 10:43:35 PDT 2012


Author: mren
Date: Tue Sep 11 12:43:35 2012
New Revision: 163635

URL: http://llvm.org/viewvc/llvm-project?rev=163635&view=rev
Log:
SimplifyCFG: preserve branch-weight metadata when creating a new switch from
a pair of switches/branches that both depend on the value of the same variable,
where the default case of the first switch/branch goes to the second one.

Code cleanup, plus fixes for a few issues:
1> handle the case where some cases of the 2nd switch are invalidated
2> correctly calculate the weights for the 2nd switch when it is a conditional eq

The test case is adapted from Alastair's original patch.

Added:
    llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll
Modified:
    llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp?rev=163635&r1=163634&r2=163635&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp Tue Sep 11 12:43:35 2012
@@ -752,38 +752,27 @@
   return false;
 }
 
-/// Tries to get a branch weight for the given instruction, returns NULL if it
-/// can't. Pos starts at 0.
-static ConstantInt* GetWeight(Instruction* I, int Pos) {
-  MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0)) {
-    if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) {
-      if (MDS->getString().equals("branch_weights")) {
-        assert(ProfMD->getNumOperands() >= 3);
-        return dyn_cast<ConstantInt>(ProfMD->getOperand(1 + Pos));
-      }
-    }
-  }
-
-  return 0;
-}
-
-/// Scale the given weights based on the successor TI's metadata. Scaling is
-/// done by multiplying every weight by the sum of the successor's weights.
-static void ScaleWeights(Instruction* STI, MutableArrayRef<uint64_t> Weights) {
-  // Sum the successor's weights
-  assert(HasBranchWeights(STI));
-  unsigned Scale = 0;
-  MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof);
-  for (unsigned i = 1; i < ProfMD->getNumOperands(); ++i) {
-    ConstantInt* CI = dyn_cast<ConstantInt>(ProfMD->getOperand(i));
+/// Append the weights stored in TI's MD_prof metadata to Weights (callers
+/// pass it in empty), ordered so the default weight is at the front. For a
+/// conditional branch on ICMP_EQ the false edge is the default: swap the pair.
+static void GetBranchWeights(TerminatorInst *TI,
+                             SmallVectorImpl<uint64_t> &Weights) {
+  MDNode* MD = TI->getMetadata(LLVMContext::MD_prof);
+  assert(MD);
+  for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
+    ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
     assert(CI);
-    Scale += CI->getValue().getZExtValue();
+    Weights.push_back(CI->getValue().getZExtValue());
   }
 
-  // Skip default, as it's replaced during the folding
-  for (unsigned i = 1; i < Weights.size(); ++i) {
-    Weights[i] *= Scale;
+  // For "br (icmp eq), T, F" the default destination is F, whose weight is
+  // metadata operand 2 (Weights[1]); swap it to the front. Other predicates
+  // keep the metadata order.
+  if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
+    assert(Weights.size() == 2);
+    ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
+    if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+      std::swap(Weights.front(), Weights.back());
   }
 }
 
@@ -838,52 +827,22 @@
 
       // Update the branch weight metadata along the way
       SmallVector<uint64_t, 8> Weights;
-      uint64_t PredDefaultWeight = 0;
       bool PredHasWeights = HasBranchWeights(PTI);
       bool SuccHasWeights = HasBranchWeights(TI);
 
-      if (PredHasWeights) {
-        MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof);
-        assert(MD);
-        for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
-          ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
-          assert(CI);
-          Weights.push_back(CI->getValue().getZExtValue());
-        }
-
-        // If the predecessor is a conditional eq, then swap the default weight
-        // to be the first entry.
-        if (BranchInst* BI = dyn_cast<BranchInst>(PTI)) {
-          assert(Weights.size() == 2);
-          ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
-
-          if (ICI->getPredicate() == ICmpInst::ICMP_EQ) {
-            std::swap(Weights.front(), Weights.back());
-          }
-        }
-
-        PredDefaultWeight = Weights.front();
-      } else if (SuccHasWeights) {
+      if (PredHasWeights)
+        GetBranchWeights(PTI, Weights);
+      else if (SuccHasWeights)
         // If there are no predecessor weights but there are successor weights,
         // populate Weights with 1, which will later be scaled to the sum of
         // successor's weights
         Weights.assign(1 + PredCases.size(), 1);
-        PredDefaultWeight = 1;
-      }
 
-      uint64_t SuccDefaultWeight = 0;
-      if (SuccHasWeights) {
-        int Index = 0;
-        if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
-          ICmpInst* ICI = dyn_cast<ICmpInst>(BI->getCondition());
-          assert(ICI);
-
-          if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
-            Index = 1;
-        }
-
-        SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue();
-      }
+      SmallVector<uint64_t, 8> SuccWeights;
+      if (SuccHasWeights)
+        GetBranchWeights(TI, SuccWeights);
+      else if (PredHasWeights)
+        SuccWeights.assign(1 + BBCases.size(), 1);
 
       if (PredDefault == BB) {
         // If this is the default destination from PTI, only the edges in TI
@@ -896,7 +855,9 @@
             // The default destination is BB, we don't need explicit targets.
             std::swap(PredCases[i], PredCases.back());
 
-            if (PredHasWeights) {
+            if (PredHasWeights || SuccHasWeights) {
+              // Increase weight for the default case.
+              Weights[0] += Weights[i+1];
               std::swap(Weights[i+1], Weights.back());
               Weights.pop_back();
             }
@@ -912,27 +873,30 @@
           NewSuccessors.push_back(BBDefault);
         }
 
-        if (SuccHasWeights) {
-          ScaleWeights(TI, Weights);
-          Weights.front() *= SuccDefaultWeight;
-        } else if (PredHasWeights) {
-          Weights.front() /= (1 + BBCases.size());
-        }
-
+        unsigned CasesFromPred = Weights.size();
+        uint64_t ValidTotalSuccWeight = 0;
         for (unsigned i = 0, e = BBCases.size(); i != e; ++i)
           if (!PTIHandled.count(BBCases[i].Value) &&
               BBCases[i].Dest != BBDefault) {
             PredCases.push_back(BBCases[i]);
             NewSuccessors.push_back(BBCases[i].Dest);
-            if (SuccHasWeights) {
-              Weights.push_back(PredDefaultWeight *
-                                GetWeight(TI, i)->getValue().getZExtValue());
-            } else if (PredHasWeights) {
-              // Split the old default's weight amongst the children
-              Weights.push_back(PredDefaultWeight / (1 + BBCases.size()));
+            if (SuccHasWeights || PredHasWeights) {
+              // The default weight is at index 0, so weight for the ith case
+              // should be at index i+1. Scale the cases from successor by
+              // PredDefaultWeight (Weights[0]).
+              Weights.push_back(Weights[0] * SuccWeights[i+1]);
+              ValidTotalSuccWeight += SuccWeights[i+1];
             }
           }
 
+        if (SuccHasWeights || PredHasWeights) {
+          ValidTotalSuccWeight += SuccWeights[0];
+          // Scale the cases from predecessor by ValidTotalSuccWeight.
+          for (unsigned i = 1; i < CasesFromPred; ++i)
+            Weights[i] *= ValidTotalSuccWeight;
+          // Scale the default weight by SuccDefaultWeight (SuccWeights[0]).
+          Weights[0] *= SuccWeights[0];
+        }
       } else {
         // FIXME: preserve branch weight metadata, similarly to the 'then'
         // above. For now, drop it.

Added: llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll?rev=163635&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll (added)
+++ llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll Tue Sep 11 12:43:35 2012
@@ -0,0 +1,92 @@
+; RUN: opt -simplifycfg -S -o - < %s | FileCheck %s
+
+declare void @func2(i32)
+declare void @func4(i32)
+declare void @func6(i32)
+declare void @func8(i32)
+
+;; test1 - create a switch with case 2 and case 4 from two branches: N == 2
+;; and N == 4.
+define void @test1(i32 %N) nounwind uwtable {
+entry:
+  %cmp = icmp eq i32 %N, 2
+  br i1 %cmp, label %if.then, label %if.else, !prof !0
+; CHECK: test1
+; CHECK: switch i32 %N
+; CHECK: ], !prof !0
+
+if.then:
+  call void @func2(i32 %N) nounwind
+  br label %if.end9
+
+if.else:
+  %cmp2 = icmp eq i32 %N, 4
+  br i1 %cmp2, label %if.then7, label %if.else8, !prof !1
+
+if.then7:
+  call void @func4(i32 %N) nounwind
+  br label %if.end
+
+if.else8:
+  call void @func8(i32 %N) nounwind
+  br label %if.end
+
+if.end:
+  br label %if.end9
+
+if.end9:
+  ret void
+}
+
+;; test2 - Merge two switches where PredDefault == BB.
+define void @test2(i32 %M, i32 %N) nounwind uwtable {
+entry:
+  %cmp = icmp sgt i32 %M, 2
+  br i1 %cmp, label %sw1, label %sw2
+
+sw1:
+  switch i32 %N, label %sw2 [
+    i32 2, label %sw.bb
+    i32 3, label %sw.bb1
+  ], !prof !2
+; CHECK: test2
+; CHECK: switch i32 %N, label %sw.epilog
+; CHECK: i32 2, label %sw.bb
+; CHECK: i32 3, label %sw.bb1
+; CHECK: i32 4, label %sw.bb5
+; CHECK: ], !prof !1
+
+sw.bb:
+  call void @func2(i32 %N) nounwind
+  br label %sw.epilog
+
+sw.bb1:
+  call void @func4(i32 %N) nounwind
+  br label %sw.epilog
+
+sw2:
+;; Here "case 2" is invalidated if control is transferred through default case
+;; of the first switch.
+  switch i32 %N, label %sw.epilog [
+    i32 2, label %sw.bb4
+    i32 4, label %sw.bb5
+  ], !prof !3
+
+sw.bb4:
+  call void @func6(i32 %N) nounwind
+  br label %sw.epilog
+
+sw.bb5:
+  call void @func8(i32 %N) nounwind
+  br label %sw.epilog
+
+sw.epilog:
+  ret void
+}
+
+!0 = metadata !{metadata !"branch_weights", i32 64, i32 4}
+!1 = metadata !{metadata !"branch_weights", i32 4, i32 64}
+; CHECK: !0 = metadata !{metadata !"branch_weights", i32 256, i32 4352, i32 16}
+!2 = metadata !{metadata !"branch_weights", i32 4, i32 4, i32 8}
+!3 = metadata !{metadata !"branch_weights", i32 8, i32 8, i32 4}
+; CHECK: !1 = metadata !{metadata !"branch_weights", i32 32, i32 48, i32 96, i32 16}





More information about the llvm-commits mailing list