[llvm-commits] [llvm] r163954 - in /llvm/trunk: lib/Transforms/Utils/SimplifyCFG.cpp test/Transforms/SimplifyCFG/preserve-branchweights.ll

Manman Ren mren at apple.com
Fri Sep 14 17:39:57 PDT 2012


Author: mren
Date: Fri Sep 14 19:39:57 2012
New Revision: 163954

URL: http://llvm.org/viewvc/llvm-project?rev=163954&view=rev
Log:
PGO: preserve branch-weight metadata when simplifying two branches with a common
destination.

Updated the previous implementation to fix a case that was not covered:
// PBI: br i1 %x, TrueDest, BB
// BI:  br i1 %y, TrueDest, FalseDest
The other case was already handled correctly:
// PBI: br i1 %x, BB, FalseDest
// BI:  br i1 %y, TrueDest, FalseDest

Also switched to plain 64-bit arithmetic instead of APInt with scaling, to
simplify the computation. Let me know if you have other opinions about this.

Modified:
    llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp
    llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights.ll

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp?rev=163954&r1=163953&r2=163954&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyCFG.cpp Fri Sep 14 19:39:57 2012
@@ -1658,7 +1658,7 @@
 /// parameters and return true, or returns false if no or invalid metadata was
 /// found.
 static bool ExtractBranchMetadata(BranchInst *BI,
-                                  APInt &ProbTrue, APInt &ProbFalse) {
+                                  uint64_t &ProbTrue, uint64_t &ProbFalse) {
   assert(BI->isConditional() &&
          "Looking for probabilities on unconditional branch?");
   MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof);
@@ -1666,35 +1666,11 @@
   ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1));
   ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2));
   if (!CITrue || !CIFalse) return false;
-  ProbTrue = CITrue->getValue();
-  ProbFalse = CIFalse->getValue();
-  assert(ProbTrue.getBitWidth() == 32 && ProbFalse.getBitWidth() == 32 &&
-         "Branch probability metadata must be 32-bit integers");
+  ProbTrue = CITrue->getValue().getZExtValue();
+  ProbFalse = CIFalse->getValue().getZExtValue();
   return true;
 }
 
-/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
-/// the event of overflow, logically-shifts all four inputs right until the
-/// multiply fits.
-static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
-                                      unsigned &BitsLost) {
-  BitsLost = 0;
-  bool Overflow = false;
-  APInt Result = A.umul_ov(B, Overflow);
-  if (Overflow) {
-    APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
-    do {
-      B = B.lshr(1);
-      ++BitsLost;
-    } while (B.ugt(MaxB));
-    A = A.lshr(BitsLost);
-    C = C.lshr(BitsLost);
-    D = D.lshr(BitsLost);
-    Result = A * B;
-  }
-  return Result;
-}
-
 /// checkCSEInPredecessor - Return true if the given instruction is available
 /// in its predecessor block. If yes, the instruction will be removed.
 ///
@@ -1919,14 +1895,53 @@
                                             New, "or.cond"));
       PBI->setCondition(NewCond);
 
+      uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
+      bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight,
+                                                  PredFalseWeight);
+      bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight,
+                                                  SuccFalseWeight);
+      SmallVector<uint64_t, 8> NewWeights;
+
       if (PBI->getSuccessor(0) == BB) {
+        if (PredHasWeights && SuccHasWeights) {
+          // PBI: br i1 %x, BB, FalseDest
+          // BI:  br i1 %y, TrueDest, FalseDest
+          //TrueWeight is TrueWeight for PBI * TrueWeight for BI.
+          NewWeights.push_back(PredTrueWeight * SuccTrueWeight);
+          //FalseWeight is FalseWeight for PBI * TotalWeight for BI +
+          //               TrueWeight for PBI * FalseWeight for BI.
+          // We assume that total weights of a BranchInst can fit into 32 bits.
+          // Therefore, we will not have overflow using 64-bit arithmetic.
+          NewWeights.push_back(PredFalseWeight * (SuccFalseWeight +
+               SuccTrueWeight) + PredTrueWeight * SuccFalseWeight);
+        }
         AddPredecessorToBlock(TrueDest, PredBlock, BB);
         PBI->setSuccessor(0, TrueDest);
       }
       if (PBI->getSuccessor(1) == BB) {
+        if (PredHasWeights && SuccHasWeights) {
+          // PBI: br i1 %x, TrueDest, BB
+          // BI:  br i1 %y, TrueDest, FalseDest
+          //TrueWeight is TrueWeight for PBI * TotalWeight for BI +
+          //              FalseWeight for PBI * TrueWeight for BI.
+          NewWeights.push_back(PredTrueWeight * (SuccFalseWeight +
+              SuccTrueWeight) + PredFalseWeight * SuccTrueWeight);
+          //FalseWeight is FalseWeight for PBI * FalseWeight for BI.
+          NewWeights.push_back(PredFalseWeight * SuccFalseWeight);
+        }
         AddPredecessorToBlock(FalseDest, PredBlock, BB);
         PBI->setSuccessor(1, FalseDest);
       }
+      if (NewWeights.size() == 2) {
+        // Halve the weights if any of them cannot fit in an uint32_t
+        FitWeights(NewWeights);
+
+        SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),NewWeights.end());
+        PBI->setMetadata(LLVMContext::MD_prof,
+                         MDBuilder(BI->getContext()).
+                         createBranchWeights(MDWeights));
+      } else
+        PBI->setMetadata(LLVMContext::MD_prof, NULL);
     } else {
       // Update PHI nodes in the common successors.
       for (unsigned i = 0, e = PHIs.size(); i != e; ++i) {
@@ -1981,90 +1996,6 @@
     // TODO: If BB is reachable from all paths through PredBlock, then we
     // could replace PBI's branch probabilities with BI's.
 
-    // Merge probability data into PredBlock's branch.
-    APInt A, B, C, D;
-    if (PBI->isConditional() && BI->isConditional() &&
-        ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
-      // Given IR which does:
-      //   bbA:
-      //     br i1 %x, label %bbB, label %bbC
-      //   bbB:
-      //     br i1 %y, label %bbD, label %bbC
-      // Let's call the probability that we take the edge from %bbA to %bbB
-      // 'a', from %bbA to %bbC, 'b', from %bbB to %bbD 'c' and from %bbB to
-      // %bbC probability 'd'.
-      //
-      // We transform the IR into:
-      //   bbA:
-      //     br i1 %z, label %bbD, label %bbC
-      // where the probability of going to %bbD is (a*c) and going to bbC is
-      // (b+a*d).
-      //
-      // Probabilities aren't stored as ratios directly. Using branch weights,
-      // we get:
-      // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
-
-      // In the event of overflow, we want to drop the LSB of the input
-      // probabilities.
-      unsigned BitsLost;
-
-      // Ignore overflow result on ProbTrue.
-      APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
-
-      APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
-      if (BitsLost) {
-        ProbTrue = ProbTrue.lshr(BitsLost*2);
-      }
-
-      APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
-      if (BitsLost) {
-        ProbTrue = ProbTrue.lshr(BitsLost*2);
-        Tmp1 = Tmp1.lshr(BitsLost*2);
-      }
-
-      APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
-      if (BitsLost) {
-        ProbTrue = ProbTrue.lshr(BitsLost*2);
-        Tmp1 = Tmp1.lshr(BitsLost*2);
-        Tmp2 = Tmp2.lshr(BitsLost*2);
-      }
-
-      bool Overflow1 = false, Overflow2 = false;
-      APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
-      APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
-
-      if (Overflow1 || Overflow2) {
-        ProbTrue = ProbTrue.lshr(1);
-        Tmp1 = Tmp1.lshr(1);
-        Tmp2 = Tmp2.lshr(1);
-        Tmp3 = Tmp3.lshr(1);
-        Tmp4 = Tmp2 + Tmp3;
-        ProbFalse = Tmp4 + Tmp1;
-      }
-
-      // The sum of branch weights must fit in 32-bits.
-      if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
-        ProbTrue = ProbTrue.lshr(1);
-        ProbFalse = ProbFalse.lshr(1);
-      }
-
-      if (ProbTrue != ProbFalse) {
-        // Normalize the result.
-        APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
-        ProbTrue = ProbTrue.udiv(GCD);
-        ProbFalse = ProbFalse.udiv(GCD);
-
-        MDBuilder MDB(BI->getContext());
-        MDNode *N = MDB.createBranchWeights(ProbTrue.getZExtValue(),
-                                            ProbFalse.getZExtValue());
-        PBI->setMetadata(LLVMContext::MD_prof, N);
-      } else {
-        PBI->setMetadata(LLVMContext::MD_prof, NULL);
-      }
-    } else {
-      PBI->setMetadata(LLVMContext::MD_prof, NULL);
-    }
-
     // Copy any debug value intrinsics into the end of PredBlock.
     for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
       if (isa<DbgInfoIntrinsic>(*I))

Modified: llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights.ll?rev=163954&r1=163953&r2=163954&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights.ll (original)
+++ llvm/trunk/test/Transforms/SimplifyCFG/preserve-branchweights.ll Fri Sep 14 19:39:57 2012
@@ -154,6 +154,26 @@
   ret void
 }
 
+;; This test is based on test1 but swapped the targets of the second branch.
+define void @test1_swap(i1 %a, i1 %b) {
+; CHECK: @test1_swap
+entry:
+  br i1 %a, label %Y, label %X, !prof !0
+; CHECK: br i1 %or.cond, label %Y, label %Z, !prof !4
+
+X:
+  %c = or i1 %b, false
+  br i1 %c, label %Y, label %Z, !prof !1
+
+Y:
+  call void @helper(i32 0)
+  ret void
+
+Z:
+  call void @helper(i32 1)
+  ret void
+}
+
 !0 = metadata !{metadata !"branch_weights", i32 3, i32 5}
 !1 = metadata !{metadata !"branch_weights", i32 1, i32 1}
 !2 = metadata !{metadata !"branch_weights", i32 1, i32 2}
@@ -165,4 +185,5 @@
 ; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5}
 ; CHECK: !2 = metadata !{metadata !"branch_weights", i32 7, i32 1, i32 2}
 ; CHECK: !3 = metadata !{metadata !"branch_weights", i32 49, i32 12, i32 24, i32 35}
-; CHECK-NOT: !4
+; CHECK: !4 = metadata !{metadata !"branch_weights", i32 11, i32 5}
+; CHECK-NOT: !5





More information about the llvm-commits mailing list