[PATCH] Fix information loss in branch probability computation.

Duncan P. N. Exon Smith dexonsmith at apple.com
Fri May 1 19:23:04 PDT 2015


> On 2015 May 1, at 13:21, Diego Novillo <dnovillo at google.com> wrote:
> 
> Hi dexonsmith,
> 
> This addresses PR 22718. When branch weights are too large, they were
> being clamped to the range [1, MaxWeightForBB]. But this clamping is
> only applied to edges that go outside the range, so it distorts the
> relative branch probabilities.
> 
> This patch changes the weight calculation to scale every branch so the
> relative probabilities are preserved.
> 
> The patch fixes an existing test that had slightly wrong branch
> probabilities due to the previous clamping. It now gets branch weights
> scaled accordingly.
> 
> http://reviews.llvm.org/D9442
> 
> Files:
>  lib/Analysis/BranchProbabilityInfo.cpp
>  test/Analysis/BranchProbabilityInfo/pr22718.ll
>  test/CodeGen/X86/MachineBranchProb.ll
> 
> <D9442.24827.patch>
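Here is a small sketch that makes the distortion concrete. It is not code from the patch. It applies both policies to the two weights on `for.body` in the new pr22718.ll test, read as unsigned 32-bit values. The helper names `clampWeights`, `scaleWeights` and `edgeProbability` are hypothetical. The cap assumes a block with two successors (`UINT32_MAX / 2`, as `getMaxWeightFor` computes). The scaling uses plain 64-bit integer arithmetic instead of the patch's `ScaledNumber`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Old behavior: clamp each weight independently into [1, Limit].
// Every weight above Limit collapses to Limit, so their ratios are lost.
std::vector<uint32_t> clampWeights(const std::vector<uint32_t> &Weights,
                                   uint32_t Limit) {
  std::vector<uint32_t> Out;
  for (uint32_t W : Weights)
    Out.push_back(std::max<uint32_t>(1, std::min(W, Limit)));
  return Out;
}

// Proportional scaling (hypothetical sketch): if the largest weight exceeds
// Limit, multiply every weight by Limit / MaxWeight using a 64-bit
// intermediate, then floor each result at 1.
std::vector<uint32_t> scaleWeights(const std::vector<uint32_t> &Weights,
                                   uint32_t Limit) {
  assert(Limit > 0 && "Limit must be positive");
  uint32_t MaxWeight = 0;
  for (uint32_t W : Weights)
    MaxWeight = std::max(MaxWeight, W);
  std::vector<uint32_t> Out;
  for (uint32_t W : Weights) {
    uint64_t Scaled = W;
    if (MaxWeight > Limit)
      Scaled = Scaled * Limit / MaxWeight; // both factors < 2^32: no overflow
    Out.push_back(std::max<uint32_t>(1, static_cast<uint32_t>(Scaled)));
  }
  return Out;
}

// Probability of edge I, computed in double for display only.
double edgeProbability(const std::vector<uint32_t> &Weights, std::size_t I) {
  uint64_t Sum = 0;
  for (uint32_t W : Weights)
    Sum += W;
  return static_cast<double>(Weights[I]) / static_cast<double>(Sum);
}
```

Take the weights {433323762, 3900009573} with a cap of `UINT32_MAX / 2`. Clamping moves the first edge from about 10% to about 16.8%. Proportional scaling keeps it at about 10%.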

> Index: lib/Analysis/BranchProbabilityInfo.cpp
> ===================================================================
> --- lib/Analysis/BranchProbabilityInfo.cpp
> +++ lib/Analysis/BranchProbabilityInfo.cpp
> @@ -21,6 +21,7 @@
>  #include "llvm/IR/LLVMContext.h"
>  #include "llvm/IR/Metadata.h"
>  #include "llvm/Support/Debug.h"
> +#include "llvm/Support/ScaledNumber.h"
>  #include "llvm/Support/raw_ostream.h"
>  
>  using namespace llvm;
> @@ -119,6 +120,26 @@
>    return UINT32_MAX / BB->getTerminator()->getNumSuccessors();
>  }
>  
> +/// \brief Determine the maximum branch weight going out of a block.
> +///
> +/// This returns the maximum branch weight annotation found in the
> +/// given MD_prof annotation \p Weights.
> +static uint32_t getMaxUnscaledWeightFor(const MDNode *WeightsNode) {
> +  uint32_t Max = 0;
> +  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
> +    ConstantInt *Weight =
> +        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
> +    if (Weight) {
> +      assert(Weight->getValue().getActiveBits() <= 32 &&
> +             "Too many bits for uint32_t");
> +      uint32_t Val = static_cast<uint32_t>(Weight->getZExtValue());
> +      if (Val > Max)
> +        Max = Val;
> +    }
> +  }
> +  assert(Max > 0 && "Maximum branch weight should not be zero.");
> +  return Max;
> +}
>  
>  /// \brief Calculate edge weights for successors lead to unreachable.
>  ///
> @@ -191,18 +212,25 @@
>      return false;
>  
>    // Build up the final weights that will be used in a temporary buffer, but
> -  // don't add them until all weights are present. Each weight value is clamped
> -  // to [1, getMaxWeightFor(BB)].
> +  // don't add them until all weights are present. Each weight value is scaled
> +  // to the range [1, getMaxWeightFor(BB)].
>    uint32_t WeightLimit = getMaxWeightFor(BB);
> +  uint32_t MaxUnscaledWeight = getMaxUnscaledWeightFor(WeightsNode);
> +  typedef ScaledNumber<uint32_t> Scaled32;
> +  Scaled32 ScalingFactor =
> +      (MaxUnscaledWeight > WeightLimit)
> +          ? Scaled32(WeightLimit, 0) / Scaled32(MaxUnscaledWeight, 0)
> +          : Scaled32::getOne();
>    SmallVector<uint32_t, 2> Weights;
>    Weights.reserve(TI->getNumSuccessors());
>    for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
>      ConstantInt *Weight =
>          mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
>      if (!Weight)
>        return false;
> -    Weights.push_back(
> -      std::max<uint32_t>(1, Weight->getLimitedValue(WeightLimit)));
> +    uint32_t ScaledWeight =
> +        (Scaled32(Weight->getZExtValue(), 0) * ScalingFactor).toInt<uint32_t>();
> +    Weights.push_back(std::max<uint32_t>(1, ScaledWeight));
>    }

Hmm.  Now I see why you thought `ScaledNumber` should maybe have some
short circuits.

I feel like `ScaledNumber` is too heavyweight for this.  In BFI
(BlockFrequencyInfo) the global scale of the numbers is hard to predict,
but here we're only dealing with the successors of a single basic block.
All the numbers are small: each weight fits in 32 bits, so a 64-bit sum
is enough.
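The bound behind "all the numbers are small" can be checked at compile time. This is a minimal sketch that assumes the worst case allowed by the `NumSuccessors < UINT32_MAX` assertion in the suggestion below, with a maximal 32-bit weight on every edge. The constant names are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Worst case for a single block: UINT32_MAX - 1 successors, each carrying
// the largest possible 32-bit weight.
constexpr uint64_t MaxSuccessors = uint64_t(UINT32_MAX) - 1;
constexpr uint64_t MaxWeight = UINT32_MAX;

// (2^32 - 2) * (2^32 - 1) < 2^64, so a uint64_t accumulator cannot overflow.
static_assert(MaxSuccessors <=
                  std::numeric_limits<uint64_t>::max() / MaxWeight,
              "64-bit sum of 32-bit weights could overflow");

constexpr uint64_t WorstCaseSum = MaxSuccessors * MaxWeight;
```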

I think this could be fixed more simply by doing something similar to
MachineBranchProbabilityInfo::getSumForBlock():

    // Check that we have a sane number of successors.
    assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

    // Copy out the weights, accumulating a 64-bit sum.
    uint64_t Sum = 0;
    SmallVector<uint32_t, 2> Weights;
    Weights.reserve(TI->getNumSuccessors());
    for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
      ConstantInt *Weight =
          mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
      if (!Weight)
        return false;
      assert(Weight->getValue().getActiveBits() <= 32 &&
             "Too many bits for uint32_t");
      Weights.push_back(std::max<uint32_t>(1, Weight->getZExtValue()));
      Sum += Weights.back();
    }
    assert(Weights.size() == TI->getNumSuccessors() && "Checked above");

    // Scale the weights down if the sum exceeds UINT32_MAX.
    if (Sum > UINT32_MAX) {
      uint64_t Scale = Sum / UINT32_MAX + 1;
      Sum = 0;
      for (auto &W : Weights) {
        W = std::max<uint32_t>(1, W / Scale);
        Sum += W;
      }
      assert(Sum <= UINT32_MAX && "Expected sum to get scaled down");
    }

    // Set the edge weights.
    for (unsigned I = 0, E = Weights.size(); I != E; ++I)
      setEdgeWeight(BB, I, Weights[I]);

The way I've coded this has a small functional change -- instead of
scaling to `UINT32_MAX/NumSuccessors` I just check that the sum fits in
`UINT32_MAX` -- but I think it's a strict improvement.
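If you want to experiment outside LLVM, here is a hedged standalone sketch of the same algorithm over a plain `std::vector`. `normalizeWeights` is a hypothetical name. The `uint64_t` inputs stand in for the values `getZExtValue()` returns.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the suggestion: floor each raw weight at
// 1, accumulate a 64-bit sum, and divide every weight by a common factor only
// when that sum does not fit in 32 bits.
std::vector<uint32_t> normalizeWeights(const std::vector<uint64_t> &Raw) {
  uint64_t Sum = 0;
  std::vector<uint32_t> Weights;
  Weights.reserve(Raw.size());
  for (uint64_t R : Raw) {
    assert(R <= UINT32_MAX && "Too many bits for uint32_t");
    Weights.push_back(std::max<uint32_t>(1, static_cast<uint32_t>(R)));
    Sum += Weights.back();
  }

  // Scale the weights down if the sum exceeds UINT32_MAX.
  if (Sum > UINT32_MAX) {
    uint64_t Scale = Sum / UINT32_MAX + 1;
    Sum = 0;
    for (uint32_t &W : Weights) {
      W = std::max<uint32_t>(1, static_cast<uint32_t>(W / Scale));
      Sum += W;
    }
    assert(Sum <= UINT32_MAX && "Expected sum to get scaled down");
  }
  return Weights;
}
```

Run it on the pr22718.ll weights {433323762, 3900009573}. The sum is 4333333335, so `Scale` is 2 and the result is {216661881, 1950004786}, still roughly 10% / 90%. Weights such as {3000000000, 1000000000} come back unchanged because their sum fits in 32 bits. A per-edge cap of `UINT32_MAX / 2` would have forced the first one down, and that is the functional change described above.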

>    assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
>    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
> Index: test/Analysis/BranchProbabilityInfo/pr22718.ll
> ===================================================================
> --- /dev/null
> +++ test/Analysis/BranchProbabilityInfo/pr22718.ll
> @@ -0,0 +1,84 @@
> +; RUN: opt < %s -analyze -branch-prob | FileCheck %s
> +
> +; In this test, the else clause is taken about 90% of the time. This was not
> +; reflected in the probability computation because the weight is larger than
> +; the branch weight cap (about 2 billion).
> +;
> +; CHECK: edge for.body -> if.then probability is 238603438 / 2386087085 = 9.9
> +; CHECK: edge for.body -> if.else probability is 2147483647 / 2386087085 = 90.0
> +
> +@y = common global i64 0, align 8
> +@x = common global i64 0, align 8
> +@.str = private unnamed_addr constant [17 x i8] c"x = %lu\0Ay = %lu\0A\00", align 1
> +
> +; Function Attrs: inlinehint nounwind uwtable
> +define i32 @main() #0 {
> +entry:
> +  %retval = alloca i32, align 4
> +  %i = alloca i64, align 8
> +  store i32 0, i32* %retval
> +  store i64 0, i64* @y, align 8
> +  store i64 0, i64* @x, align 8
> +  call void @srand(i32 422304) #3
> +  store i64 0, i64* %i, align 8
> +  br label %for.cond
> +
> +for.cond:                                         ; preds = %for.inc, %entry
> +  %0 = load i64, i64* %i, align 8
> +  %cmp = icmp ult i64 %0, 13000000000
> +  br i1 %cmp, label %for.body, label %for.end, !prof !1
> +
> +for.body:                                         ; preds = %for.cond
> +  %call = call i32 @rand() #3
> +  %conv = sitofp i32 %call to double
> +  %mul = fmul double %conv, 1.000000e+02
> +  %div = fdiv double %mul, 0x41E0000000000000
> +  %cmp1 = fcmp ogt double %div, 9.000000e+01
> +  br i1 %cmp1, label %if.then, label %if.else, !prof !2
> +
> +if.then:                                          ; preds = %for.body
> +  %1 = load i64, i64* @x, align 8
> +  %inc = add i64 %1, 1
> +  store i64 %inc, i64* @x, align 8
> +  br label %if.end
> +
> +if.else:                                          ; preds = %for.body
> +  %2 = load i64, i64* @y, align 8
> +  %inc3 = add i64 %2, 1
> +  store i64 %inc3, i64* @y, align 8
> +  br label %if.end
> +
> +if.end:                                           ; preds = %if.else, %if.then
> +  br label %for.inc
> +
> +for.inc:                                          ; preds = %if.end
> +  %3 = load i64, i64* %i, align 8
> +  %inc4 = add i64 %3, 1
> +  store i64 %inc4, i64* %i, align 8
> +  br label %for.cond
> +
> +for.end:                                          ; preds = %for.cond
> +  %4 = load i64, i64* @x, align 8
> +  %5 = load i64, i64* @y, align 8
> +  %call5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([17 x i8], [17 x i8]* @.str, i32 0, i32 0), i64 %4, i64 %5)
> +  ret i32 0
> +}
> +
> +; Function Attrs: nounwind
> +declare void @srand(i32) #1
> +
> +; Function Attrs: nounwind
> +declare i32 @rand() #1
> +
> +declare i32 @printf(i8*, ...) #2
> +
> +attributes #0 = { inlinehint nounwind uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
> +attributes #1 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
> +attributes #2 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
> +attributes #3 = { nounwind }
> +
> +!llvm.ident = !{!0}
> +
> +!0 = !{!"clang version 3.7.0 (trunk 236218) (llvm/trunk 236235)"}
> +!1 = !{!"branch_weights", i32 -1044967295, i32 1}
> +!2 = !{!"branch_weights", i32 433323762, i32 -394957723}
> Index: test/CodeGen/X86/MachineBranchProb.ll
> ===================================================================
> --- test/CodeGen/X86/MachineBranchProb.ll
> +++ test/CodeGen/X86/MachineBranchProb.ll
> @@ -18,9 +18,9 @@
>    %or.cond = or i1 %tobool, %cmp4
>    br i1 %or.cond, label %for.inc20, label %for.inc, !prof !0
>  ; CHECK: BB#1: derived from LLVM BB %for.cond2
> -; CHECK: Successors according to CFG: BB#3(56008718) BB#4(2203492365)
> +; CHECK: Successors according to CFG: BB#3(33787703) BB#4(2181271350)
>  ; CHECK: BB#4: derived from LLVM BB %for.cond2
> -; CHECK: Successors according to CFG: BB#3(112017436) BB#2(4294967294)
> +; CHECK: Successors according to CFG: BB#3(67575407) BB#2(4294967294)
>  
>  for.inc:                                          ; preds = %for.cond2
>    %shl = shl i32 %bit.0, 1
> 
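The `!prof` operands in pr22718.ll are printed as signed `i32`, so the negative values are large unsigned weights. The sketch below uses hypothetical helper names and plain C++ rather than LLVM APIs. It shows the reinterpretation that `getZExtValue()` effectively performs on a 32-bit constant, and it checks that `!2` describes a roughly 10% / 90% split.

```cpp
#include <cassert>
#include <cstdint>

// Reinterpret a branch weight printed in textual IR as a signed i32 as the
// unsigned 32-bit value it encodes (well-defined modular conversion).
uint32_t weightFromPrintedI32(int32_t Printed) {
  return static_cast<uint32_t>(Printed);
}

// Probability of the first of two successors, as a percentage. The sum is
// taken in 64 bits so two large 32-bit weights cannot overflow.
double firstEdgePercent(uint32_t First, uint32_t Second) {
  uint64_t Sum = uint64_t(First) + Second;
  return 100.0 * First / static_cast<double>(Sum);
}
```

For example, -394957723 maps to 3900009573 and -1044967295 maps to 3250000001.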