[PATCH] Fix information loss in branch probability computation.

Duncan Exon Smith dexonsmith at apple.com
Fri May 1 19:42:45 PDT 2015



> On May 1, 2015, at 7:23 PM, Duncan P. N. Exon Smith <dexonsmith at apple.com> wrote:
> 
> 
>> On 2015 May 1, at 13:21, Diego Novillo <dnovillo at google.com> wrote:
>> 
>> Hi dexonsmith,
>> 
>> This addresses PR 22718. When branch weights are too large, they were
>> being clamped to the range [1, MaxWeightForBB]. But this clamping is
>> only applied to edges that go outside the range, so it distorts the
>> relative branch probabilities.
>> 
>> This patch changes the weight calculation to scale every branch so the
>> relative probabilities are preserved.
>> 
>> The patch fixes an existing test that had slightly wrong branch
>> probabilities due to the previous clamping. It now gets branch weights
>> scaled accordingly.
>> 
>> http://reviews.llvm.org/D9442
>> 
>> Files:
>> lib/Analysis/BranchProbabilityInfo.cpp
>> test/Analysis/BranchProbabilityInfo/pr22718.ll
>> test/CodeGen/X86/MachineBranchProb.ll
>> 
>> <D9442.24827.patch>
> 
>> Index: lib/Analysis/BranchProbabilityInfo.cpp
>> ===================================================================
>> --- lib/Analysis/BranchProbabilityInfo.cpp
>> +++ lib/Analysis/BranchProbabilityInfo.cpp
>> @@ -21,6 +21,7 @@
>> #include "llvm/IR/LLVMContext.h"
>> #include "llvm/IR/Metadata.h"
>> #include "llvm/Support/Debug.h"
>> +#include "llvm/Support/ScaledNumber.h"
>> #include "llvm/Support/raw_ostream.h"
>> 
>> using namespace llvm;
>> @@ -119,6 +120,26 @@
>>   return UINT32_MAX / BB->getTerminator()->getNumSuccessors();
>> }
>> 
>> +/// \brief Determine the maximum branch weight going out of a block.
>> +///
>> +/// This returns the maximum branch weight annotation found in the
>> +/// given MD_prof annotation \p WeightsNode.
>> +static uint32_t getMaxUnscaledWeightFor(const MDNode *WeightsNode) {
>> +  uint32_t Max = 0;
>> +  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
>> +    ConstantInt *Weight =
>> +        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
>> +    if (Weight) {
>> +      assert(Weight->getValue().getActiveBits() <= 32 &&
>> +             "Too many bits for uint32_t");
>> +      uint32_t Val = static_cast<uint32_t>(Weight->getZExtValue());
>> +      if (Val > Max)
>> +        Max = Val;
>> +    }
>> +  }
>> +  assert(Max > 0 && "Maximum branch weight should not be zero.");
>> +  return Max;
>> +}
>> 
>> /// \brief Calculate edge weights for successors lead to unreachable.
>> ///
>> @@ -191,18 +212,25 @@
>>     return false;
>> 
>>   // Build up the final weights that will be used in a temporary buffer, but
>> -  // don't add them until all weights are present. Each weight value is clamped
>> -  // to [1, getMaxWeightFor(BB)].
>> +  // don't add them until all weights are present. Each weight value is scaled
>> +  // to the range [1, getMaxWeightFor(BB)].
>>   uint32_t WeightLimit = getMaxWeightFor(BB);
>> +  uint32_t MaxUnscaledWeight = getMaxUnscaledWeightFor(WeightsNode);
>> +  typedef ScaledNumber<uint32_t> Scaled32;
>> +  Scaled32 ScalingFactor =
>> +      (MaxUnscaledWeight > WeightLimit)
>> +          ? Scaled32(WeightLimit, 0) / Scaled32(MaxUnscaledWeight, 0)
>> +          : Scaled32::getOne();
>>   SmallVector<uint32_t, 2> Weights;
>>   Weights.reserve(TI->getNumSuccessors());
>>   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
>>     ConstantInt *Weight =
>>         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
>>     if (!Weight)
>>       return false;
>> -    Weights.push_back(
>> -      std::max<uint32_t>(1, Weight->getLimitedValue(WeightLimit)));
>> +    uint32_t ScaledWeight =
>> +        (Scaled32(Weight->getZExtValue(), 0) * ScalingFactor).toInt<uint32_t>();
>> +    Weights.push_back(std::max<uint32_t>(1, ScaledWeight));
>>   }
> 
> Hmm.  Now I see why you thought `ScaledNumber` should maybe have some
> short circuits.
> 
> I feel like `ScaledNumber` is too heavy-weight for this.  In BFI the
> global scale of the numbers is hard to predict, but here we're just
> dealing with a single basic block.  All the numbers are small.
> 
> I think this could be fixed more simply by doing something similar to
> MachineBranchProbabilityInfo::getSumForBlock():
> 
>    // Check that we have a sane number of successors.
>    assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
> 
>    // Copy out the weights, accumulating a 64-bit sum.
>    uint64_t Sum = 0;
>    SmallVector<uint32_t, 2> Weights;
>    Weights.reserve(TI->getNumSuccessors());
>    for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
>      ConstantInt *Weight =
>          mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
>      if (!Weight)
>        return false;
>      assert(Weight->getValue().getActiveBits() <= 32 &&
>             "Too many bits for uint32_t");
>      Weights.push_back(std::max<uint32_t>(1, Weight->getZExtValue()));
>      Sum += Weights.back();
>    }
>    assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
> 
>    // Scale the weights down if the sum exceeds UINT32_MAX.
>    if (Sum > UINT32_MAX) {
>      uint64_t Scale = Sum / UINT32_MAX + 1;
>      Sum = 0;
>      for (auto &W : Weights) {
>        W = std::max<uint32_t>(1, W / Scale);
>        Sum += W;
>      }
>      assert(Sum <= UINT32_MAX && "Expected sum to get scaled down");

Actually, maybe this assert could fire because of the floor of 1: any weight that scales down to 0 is bumped back up to 1, which can push the sum past UINT32_MAX.  (Side question: why does it have a floor of 1 here?  We don't seem to have that floor at the machine level, so maybe we don't need it here either.  IIRC, BFI adds a floor later.  Or maybe this should even be a verifier check for !prof attachments.)

In which case, just rewrite the above to gather the max in the first loop, and calculate this as (excuse my thumbs, I'm on a phone now):

if (Max > UINT32_MAX / numsucc) {
  Scale = ...;
  for (auto &W : ...)
     W = max(1, W / Scale);
}

The point is that I think plain integer math is sufficient here and isn't really more complicated.

>    }
> 
>    // Set the edge weights.
>    for (unsigned I = 0, E = Weights.size(); I != E; ++I)
>      setEdgeWeight(BB, I, Weights[I]);
> 
> The way I've coded this has a small functional change -- instead of
> scaling to `UINT32_MAX/NumSuccessors` I just check that the sum fits in
> `UINT32_MAX` -- but I think it's a strict improvement.
> 
>>   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
>>   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
>> Index: test/Analysis/BranchProbabilityInfo/pr22718.ll
>> ===================================================================
>> --- /dev/null
>> +++ test/Analysis/BranchProbabilityInfo/pr22718.ll
>> @@ -0,0 +1,84 @@
>> +; RUN: opt < %s -analyze -branch-prob | FileCheck %s
>> +
>> +; In this test, the else clause is taken about 90% of the time. This was not
>> +; reflected in the probability computation because the weight is larger than
>> +; the branch weight cap (about 2 billion).
>> +;
>> +; CHECK: edge for.body -> if.then probability is 238603438 / 2386087085 = 9.9
>> +; CHECK: edge for.body -> if.else probability is 2147483647 / 2386087085 = 90.0
>> +
>> +@y = common global i64 0, align 8
>> +@x = common global i64 0, align 8
>> +@.str = private unnamed_addr constant [17 x i8] c"x = %lu\0Ay = %lu\0A\00", align 1
>> +
>> +; Function Attrs: inlinehint nounwind uwtable
>> +define i32 @main() #0 {
>> +entry:
>> +  %retval = alloca i32, align 4
>> +  %i = alloca i64, align 8
>> +  store i32 0, i32* %retval
>> +  store i64 0, i64* @y, align 8
>> +  store i64 0, i64* @x, align 8
>> +  call void @srand(i32 422304) #3
>> +  store i64 0, i64* %i, align 8
>> +  br label %for.cond
>> +
>> +for.cond:                                         ; preds = %for.inc, %entry
>> +  %0 = load i64, i64* %i, align 8
>> +  %cmp = icmp ult i64 %0, 13000000000
>> +  br i1 %cmp, label %for.body, label %for.end, !prof !1
>> +
>> +for.body:                                         ; preds = %for.cond
>> +  %call = call i32 @rand() #3
>> +  %conv = sitofp i32 %call to double
>> +  %mul = fmul double %conv, 1.000000e+02
>> +  %div = fdiv double %mul, 0x41E0000000000000
>> +  %cmp1 = fcmp ogt double %div, 9.000000e+01
>> +  br i1 %cmp1, label %if.then, label %if.else, !prof !2
>> +
>> +if.then:                                          ; preds = %for.body
>> +  %1 = load i64, i64* @x, align 8
>> +  %inc = add i64 %1, 1
>> +  store i64 %inc, i64* @x, align 8
>> +  br label %if.end
>> +
>> +if.else:                                          ; preds = %for.body
>> +  %2 = load i64, i64* @y, align 8
>> +  %inc3 = add i64 %2, 1
>> +  store i64 %inc3, i64* @y, align 8
>> +  br label %if.end
>> +
>> +if.end:                                           ; preds = %if.else, %if.then
>> +  br label %for.inc
>> +
>> +for.inc:                                          ; preds = %if.end
>> +  %3 = load i64, i64* %i, align 8
>> +  %inc4 = add i64 %3, 1
>> +  store i64 %inc4, i64* %i, align 8
>> +  br label %for.cond
>> +
>> +for.end:                                          ; preds = %for.cond
>> +  %4 = load i64, i64* @x, align 8
>> +  %5 = load i64, i64* @y, align 8
>> +  %call5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([17 x i8], [17 x i8]* @.str, i32 0, i32 0), i64 %4, i64 %5)
>> +  ret i32 0
>> +}
>> +
>> +; Function Attrs: nounwind
>> +declare void @srand(i32) #1
>> +
>> +; Function Attrs: nounwind
>> +declare i32 @rand() #1
>> +
>> +declare i32 @printf(i8*, ...) #2
>> +
>> +attributes #0 = { inlinehint nounwind uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #1 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #2 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #3 = { nounwind }
>> +
>> +!llvm.ident = !{!0}
>> +
>> +!0 = !{!"clang version 3.7.0 (trunk 236218) (llvm/trunk 236235)"}
>> +!1 = !{!"branch_weights", i32 -1044967295, i32 1}
>> +!2 = !{!"branch_weights", i32 433323762, i32 -394957723}
>> Index: test/CodeGen/X86/MachineBranchProb.ll
>> ===================================================================
>> --- test/CodeGen/X86/MachineBranchProb.ll
>> +++ test/CodeGen/X86/MachineBranchProb.ll
>> @@ -18,9 +18,9 @@
>>   %or.cond = or i1 %tobool, %cmp4
>>   br i1 %or.cond, label %for.inc20, label %for.inc, !prof !0
>> ; CHECK: BB#1: derived from LLVM BB %for.cond2
>> -; CHECK: Successors according to CFG: BB#3(56008718) BB#4(2203492365)
>> +; CHECK: Successors according to CFG: BB#3(33787703) BB#4(2181271350)
>> ; CHECK: BB#4: derived from LLVM BB %for.cond2
>> -; CHECK: Successors according to CFG: BB#3(112017436) BB#2(4294967294)
>> +; CHECK: Successors according to CFG: BB#3(67575407) BB#2(4294967294)
>> 
>> for.inc:                                          ; preds = %for.cond2
>>   %shl = shl i32 %bit.0, 1
> 




More information about the llvm-commits mailing list