[PATCH] Fix information loss in branch probability computation.
Duncan Exon Smith
dexonsmith at apple.com
Fri May 1 19:42:45 PDT 2015
> On May 1, 2015, at 7:23 PM, Duncan P. N. Exon Smith <dexonsmith at apple.com> wrote:
>
>
>> On 2015 May 1, at 13:21, Diego Novillo <dnovillo at google.com> wrote:
>>
>> Hi dexonsmith,
>>
>> This addresses PR 22718. When branch weights are too large, they were
>> being clamped to the range [1, MaxWeightForBB]. But this clamping is
>> only applied to edges that go outside the range, so it distorts the
>> relative branch probabilities.
>>
>> This patch changes the weight calculation to scale every branch so the
>> relative probabilities are preserved.
>>
>> The patch fixes an existing test that had slightly wrong branch
>> probabilities due to the previous clamping. It now gets branch weights
>> scaled accordingly.
>>
>> http://reviews.llvm.org/D9442
>>
>> Files:
>> lib/Analysis/BranchProbabilityInfo.cpp
>> test/Analysis/BranchProbabilityInfo/pr22718.ll
>> test/CodeGen/X86/MachineBranchProb.ll
>>
>> <D9442.24827.patch>
>
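To make the distortion described above concrete, here is a minimal standalone sketch (not the patch itself) that compares per-edge clamping with proportional scaling. The helper names `clampWeights` and `scaleWeights` are made up for illustration, and the sketch assumes plain 64-bit integer math rather than `ScaledNumber`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: clamp each weight independently into [1, Limit],
// which is the behaviour the patch description says distorts the ratios.
std::vector<uint32_t> clampWeights(const std::vector<uint32_t> &Weights,
                                   uint32_t Limit) {
  std::vector<uint32_t> Out;
  Out.reserve(Weights.size());
  for (uint32_t W : Weights)
    Out.push_back(std::max<uint32_t>(1, std::min(W, Limit)));
  return Out;
}

// Hypothetical helper: if the largest weight exceeds Limit, multiply every
// weight by Limit / Max (64-bit intermediate, truncating), with a floor of 1.
std::vector<uint32_t> scaleWeights(const std::vector<uint32_t> &Weights,
                                   uint32_t Limit) {
  uint32_t Max = 0;
  for (uint32_t W : Weights)
    Max = std::max(Max, W);
  std::vector<uint32_t> Out;
  Out.reserve(Weights.size());
  for (uint32_t W : Weights) {
    // Only divide when Max > Limit, so Max is never zero in the division.
    uint64_t Scaled = Max > Limit ? uint64_t(W) * Limit / Max : W;
    Out.push_back(std::max<uint32_t>(1, static_cast<uint32_t>(Scaled)));
  }
  return Out;
}
```

With weights {1000, 9000} and a limit of 3000, clamping gives {1000, 3000}, so a 10%/90% split turns into 25%/75%. Scaling gives {333, 3000}, which stays close to 10%/90%.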
>> Index: lib/Analysis/BranchProbabilityInfo.cpp
>> ===================================================================
>> --- lib/Analysis/BranchProbabilityInfo.cpp
>> +++ lib/Analysis/BranchProbabilityInfo.cpp
>> @@ -21,6 +21,7 @@
>> #include "llvm/IR/LLVMContext.h"
>> #include "llvm/IR/Metadata.h"
>> #include "llvm/Support/Debug.h"
>> +#include "llvm/Support/ScaledNumber.h"
>> #include "llvm/Support/raw_ostream.h"
>>
>> using namespace llvm;
>> @@ -119,6 +120,26 @@
>> return UINT32_MAX / BB->getTerminator()->getNumSuccessors();
>> }
>>
>> +/// \brief Determine the maximum branch weight going out of a block.
>> +///
>> +/// This returns the maximum branch weight annotation found in the
>> +/// given MD_prof annotation \p WeightsNode.
>> +static uint32_t getMaxUnscaledWeightFor(const MDNode *WeightsNode) {
>> + uint32_t Max = 0;
>> + for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
>> + ConstantInt *Weight =
>> + mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
>> + if (Weight) {
>> + assert(Weight->getValue().getActiveBits() <= 32 &&
>> + "Too many bits for uint32_t");
>> + uint32_t Val = static_cast<uint32_t>(Weight->getZExtValue());
>> + if (Val > Max)
>> + Max = Val;
>> + }
>> + }
>> + assert(Max > 0 && "Maximum branch weight should not be zero.");
>> + return Max;
>> +}
>>
>> /// \brief Calculate edge weights for successors lead to unreachable.
>> ///
>> @@ -191,18 +212,25 @@
>> return false;
>>
>> // Build up the final weights that will be used in a temporary buffer, but
>> - // don't add them until all weights are present. Each weight value is clamped
>> - // to [1, getMaxWeightFor(BB)].
>> + // don't add them until all weights are present. Each weight value is scaled
>> + // to the range [1, getMaxWeightFor(BB)].
>> uint32_t WeightLimit = getMaxWeightFor(BB);
>> + uint32_t MaxUnscaledWeight = getMaxUnscaledWeightFor(WeightsNode);
>> + typedef ScaledNumber<uint32_t> Scaled32;
>> + Scaled32 ScalingFactor =
>> + (MaxUnscaledWeight > WeightLimit)
>> + ? Scaled32(WeightLimit, 0) / Scaled32(MaxUnscaledWeight, 0)
>> + : Scaled32::getOne();
>> SmallVector<uint32_t, 2> Weights;
>> Weights.reserve(TI->getNumSuccessors());
>> for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
>> ConstantInt *Weight =
>> mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
>> if (!Weight)
>> return false;
>> - Weights.push_back(
>> - std::max<uint32_t>(1, Weight->getLimitedValue(WeightLimit)));
>> + uint32_t ScaledWeight =
>> + (Scaled32(Weight->getZExtValue(), 0) * ScalingFactor).toInt<uint32_t>();
>> + Weights.push_back(std::max<uint32_t>(1, ScaledWeight));
>> }
>
> Hmm. Now I see why you thought `ScaledNumber` should maybe have some
> short circuits.
>
> I feel like `ScaledNumber` is too heavy-weight for this. In BFI the
> global scale of the numbers is hard to predict, but here we're just
> dealing with a single basic block. All the numbers are small.
>
> I think this could be fixed more simply by doing something similar to
> MachineBranchProbabilityInfo::getSumForBlock():
>
> // Check that we have a sane number of successors.
> assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
>
> // Copy out the weights, accumulating a 64-bit sum.
> uint64_t Sum = 0;
> SmallVector<uint32_t, 2> Weights;
> Weights.reserve(TI->getNumSuccessors());
> for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
> ConstantInt *Weight =
> mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
> if (!Weight)
> return false;
> assert(Weight->getValue().getActiveBits() <= 32 &&
> "Too many bits for uint32_t");
> Weights.push_back(std::max<uint32_t>(1, Weight->getZExtValue()));
> Sum += Weights.back();
> }
> assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
>
> // Scale the weights down if the sum exceeds UINT32_MAX.
> if (Sum > UINT32_MAX) {
> uint64_t Scale = Sum / UINT32_MAX + 1;
> Sum = 0;
> for (auto &W : Weights)
> Sum += W = std::max<uint32_t>(1, W / Scale);
> assert(Sum <= UINT32_MAX && "Expected sum to get scaled down");
Actually, maybe this assert could fire because of the floor of 1.
(Side question: why does it have a floor of 1 here? We don't seem to
have that floor at the machine level... maybe we don't need it here
either. IIRC, BFI adds a floor later. Or maybe this should even be a
verifier check for !prof attachments.)

In which case, just rewrite the above to gather the max in the first
loop, and calculate this as (excuse my thumbs, I'm on a phone now):

    if (Max > UINT32_MAX / numsucc) {
      Scale = ...;
      for (auto &W : ...)
        W = max(1, W / Scale);
    }

The point is I think normal integer math should be sufficient and isn't
really more complicated.
> }
>
> // Set the edge weights.
> for (unsigned I = 0, E = Weights.size(); I != E; ++I)
> setEdgeWeight(BB, I, Weights[I]);
>
> The way I've coded this has a small functional change -- instead of
> scaling to `UINT32_MAX/NumSuccessors` I just check that the sum fits in
> `UINT32_MAX` -- but I think it's a strict improvement.
>
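For anyone who wants to try the sum-based rescaling outside of LLVM, here is a standalone sketch of the loop quoted above, operating on a plain `std::vector<uint32_t>`. The function name `scaleToFit32` is hypothetical. The sketch keeps the floor of 1 and, because of that floor, deliberately does not assert that the new sum fits in 32 bits.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the quoted loop: apply a floor of 1,
// accumulate a 64-bit sum, and divide every weight by a common factor if
// the sum does not fit in 32 bits. Returns the sum after any rescaling.
uint64_t scaleToFit32(std::vector<uint32_t> &Weights) {
  uint64_t Sum = 0;
  for (uint32_t &W : Weights) {
    W = std::max<uint32_t>(1, W);
    Sum += W;
  }

  if (Sum > UINT32_MAX) {
    // Smallest integer divisor that brings the unfloored sum below UINT32_MAX.
    uint64_t Scale = Sum / UINT32_MAX + 1;
    Sum = 0;
    for (uint32_t &W : Weights) {
      W = std::max<uint32_t>(1, static_cast<uint32_t>(W / Scale));
      Sum += W;
    }
  }
  return Sum;
}
```

Feeding it the two weights from the new pr22718.ll test, read as unsigned (433323762 and 3900009573), gives a scale of 2 and weights {216661881, 1950004786}, so the roughly 10%/90% ratio is preserved.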
>> assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
>> for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
>> Index: test/Analysis/BranchProbabilityInfo/pr22718.ll
>> ===================================================================
>> --- /dev/null
>> +++ test/Analysis/BranchProbabilityInfo/pr22718.ll
>> @@ -0,0 +1,84 @@
>> +; RUN: opt < %s -analyze -branch-prob | FileCheck %s
>> +
>> +; In this test, the else clause is taken about 90% of the time. This was not
>> +; reflected in the probability computation because the weight is larger than
>> +; the branch weight cap (about 2 billion).
>> +;
>> +; CHECK: edge for.body -> if.then probability is 238603438 / 2386087085 = 9.9
>> +; CHECK: edge for.body -> if.else probability is 2147483647 / 2386087085 = 90.0
>> +
>> +@y = common global i64 0, align 8
>> +@x = common global i64 0, align 8
>> +@.str = private unnamed_addr constant [17 x i8] c"x = %lu\0Ay = %lu\0A\00", align 1
>> +
>> +; Function Attrs: inlinehint nounwind uwtable
>> +define i32 @main() #0 {
>> +entry:
>> + %retval = alloca i32, align 4
>> + %i = alloca i64, align 8
>> + store i32 0, i32* %retval
>> + store i64 0, i64* @y, align 8
>> + store i64 0, i64* @x, align 8
>> + call void @srand(i32 422304) #3
>> + store i64 0, i64* %i, align 8
>> + br label %for.cond
>> +
>> +for.cond: ; preds = %for.inc, %entry
>> + %0 = load i64, i64* %i, align 8
>> + %cmp = icmp ult i64 %0, 13000000000
>> + br i1 %cmp, label %for.body, label %for.end, !prof !1
>> +
>> +for.body: ; preds = %for.cond
>> + %call = call i32 @rand() #3
>> + %conv = sitofp i32 %call to double
>> + %mul = fmul double %conv, 1.000000e+02
>> + %div = fdiv double %mul, 0x41E0000000000000
>> + %cmp1 = fcmp ogt double %div, 9.000000e+01
>> + br i1 %cmp1, label %if.then, label %if.else, !prof !2
>> +
>> +if.then: ; preds = %for.body
>> + %1 = load i64, i64* @x, align 8
>> + %inc = add i64 %1, 1
>> + store i64 %inc, i64* @x, align 8
>> + br label %if.end
>> +
>> +if.else: ; preds = %for.body
>> + %2 = load i64, i64* @y, align 8
>> + %inc3 = add i64 %2, 1
>> + store i64 %inc3, i64* @y, align 8
>> + br label %if.end
>> +
>> +if.end: ; preds = %if.else, %if.then
>> + br label %for.inc
>> +
>> +for.inc: ; preds = %if.end
>> + %3 = load i64, i64* %i, align 8
>> + %inc4 = add i64 %3, 1
>> + store i64 %inc4, i64* %i, align 8
>> + br label %for.cond
>> +
>> +for.end: ; preds = %for.cond
>> + %4 = load i64, i64* @x, align 8
>> + %5 = load i64, i64* @y, align 8
>> + %call5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([17 x i8], [17 x i8]* @.str, i32 0, i32 0), i64 %4, i64 %5)
>> + ret i32 0
>> +}
>> +
>> +; Function Attrs: nounwind
>> +declare void @srand(i32) #1
>> +
>> +; Function Attrs: nounwind
>> +declare i32 @rand() #1
>> +
>> +declare i32 @printf(i8*, ...) #2
>> +
>> +attributes #0 = { inlinehint nounwind uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #1 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #2 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+sse,+sse2" "unsafe-fp-math"="false" "use-soft-float"="false" }
>> +attributes #3 = { nounwind }
>> +
>> +!llvm.ident = !{!0}
>> +
>> +!0 = !{!"clang version 3.7.0 (trunk 236218) (llvm/trunk 236235)"}
>> +!1 = !{!"branch_weights", i32 -1044967295, i32 1}
>> +!2 = !{!"branch_weights", i32 433323762, i32 -394957723}
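A note on reading the metadata above: the IR printer shows i32 constants as signed, so the negative weights are really large unsigned values once read through getZExtValue(). The sketch below redoes the arithmetic for !2 with plain integer math. The helper names `asUnsigned` and `scaleWeight` are hypothetical, and the limit is assumed to be UINT32_MAX / 2, which is what getMaxWeightFor() returns for a two-successor block.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: reinterpret a weight printed as a signed i32 as the
// unsigned 32-bit value it actually encodes.
uint32_t asUnsigned(int32_t Printed) {
  return static_cast<uint32_t>(Printed);
}

// Hypothetical helper: scale W by Limit / Max using a 64-bit intermediate,
// truncating toward zero. This is plain integer math, not ScaledNumber, so
// rounding may differ from the patch for other inputs.
uint32_t scaleWeight(uint32_t W, uint32_t Max, uint32_t Limit) {
  assert(Max != 0 && "Maximum weight must be non-zero");
  return static_cast<uint32_t>(uint64_t(W) * Limit / Max);
}
```

For !2, -394957723 reads as 3900009573. Scaling 433323762 by 2147483647 / 3900009573 gives 238603438, and the larger weight scales to 2147483647. These are the two numerators in the CHECK lines, and their sum is the denominator 2386087085.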
>> Index: test/CodeGen/X86/MachineBranchProb.ll
>> ===================================================================
>> --- test/CodeGen/X86/MachineBranchProb.ll
>> +++ test/CodeGen/X86/MachineBranchProb.ll
>> @@ -18,9 +18,9 @@
>> %or.cond = or i1 %tobool, %cmp4
>> br i1 %or.cond, label %for.inc20, label %for.inc, !prof !0
>> ; CHECK: BB#1: derived from LLVM BB %for.cond2
>> -; CHECK: Successors according to CFG: BB#3(56008718) BB#4(2203492365)
>> +; CHECK: Successors according to CFG: BB#3(33787703) BB#4(2181271350)
>> ; CHECK: BB#4: derived from LLVM BB %for.cond2
>> -; CHECK: Successors according to CFG: BB#3(112017436) BB#2(4294967294)
>> +; CHECK: Successors according to CFG: BB#3(67575407) BB#2(4294967294)
>>
>> for.inc: ; preds = %for.cond2
>> %shl = shl i32 %bit.0, 1
>