[llvm-commits] patch: go crazy, compute bits for an entire add instruction
Jay Foad
jay.foad at gmail.com
Thu Jul 21 03:00:18 PDT 2011
On 21 July 2011 09:06, Nick Lewycky <nicholas at mxc.ca> wrote:
> Jay, would you be willing to review this updated patch, now for add+sub?
Now the add/sub bits. (As a unified diff fan I hate to say it, but I
think a context diff would be much easier to read for this bit!)
> Index: lib/Analysis/ValueTracking.cpp
> ===================================================================
> --- lib/Analysis/ValueTracking.cpp (revision 135567)
> +++ lib/Analysis/ValueTracking.cpp (working copy)
> @@ -377,86 +377,66 @@
> return;
> }
> break;
> + case Instruction::Add: // fall-through
> case Instruction::Sub: {
> - if (ConstantInt *CLHS = dyn_cast<ConstantInt>(I->getOperand(0))) {
> - // We know that the top bits of C-X are clear if X contains less bits
> - // than C (i.e. no wrap-around can happen). For example, 20-X is
> - // positive if we can prove that X is >= 0 and < 16.
> - if (!CLHS->getValue().isNegative()) {
> - unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros();
> - // NLZ can't be BitWidth with no sign bit
> - APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1);
> - ComputeMaskedBits(I->getOperand(1), MaskV, KnownZero2, KnownOne2,
> - TD, Depth+1);
> -
> - // If all of the MaskV bits are known to be zero, then we know the
> - // output top bits are zero, because we now know that the output is
> - // from [0-C].
> - if ((KnownZero2 & MaskV) == MaskV) {
> - unsigned NLZ2 = CLHS->getValue().countLeadingZeros();
> - // Top bits known zero.
> - KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2) & Mask;
> - }
> - }
> - }
> - }
> - // fall through
> - case Instruction::Add: {
> - // If one of the operands has trailing zeros, then the bits that the
> - // other operand has in those bit positions will be preserved in the
> - // result. For an add, this works with either operand. For a subtract,
> - // this only works if the known zeros are in the right operand.
> + APInt Mask2 = APInt::getAllOnesValue(BitWidth);
> APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
> - APInt Mask2 = APInt::getLowBitsSet(BitWidth,
> - BitWidth - Mask.countLeadingZeros());
> ComputeMaskedBits(I->getOperand(0), Mask2, LHSKnownZero, LHSKnownOne, TD,
> Depth+1);
> - assert((LHSKnownZero & LHSKnownOne) == 0 &&
> - "Bits known to be one AND zero?");
> - unsigned LHSKnownZeroOut = LHSKnownZero.countTrailingOnes();
> + if (LHSKnownZero.isMinValue() && LHSKnownOne.isMinValue())
> + return;
If I ruled the world, ComputeMaskedBits(UndefValue) would return with
both KnownZero and KnownOne set to an all-ones value. But I don't, so
it probably doesn't.
>
> - ComputeMaskedBits(I->getOperand(1), Mask2, KnownZero2, KnownOne2, TD,
> + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
> + ComputeMaskedBits(I->getOperand(1), Mask2, RHSKnownZero, RHSKnownOne, TD,
> Depth+1);
> - assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
> - unsigned RHSKnownZeroOut = KnownZero2.countTrailingOnes();
> + if (RHSKnownZero.isMinValue() && RHSKnownOne.isMinValue())
> + return;
>
> - // Determine which operand has more trailing zeros, and use that
> - // many bits from the other operand.
> - if (LHSKnownZeroOut > RHSKnownZeroOut) {
> - if (I->getOpcode() == Instruction::Add) {
> - APInt Mask = APInt::getLowBitsSet(BitWidth, LHSKnownZeroOut);
> - KnownZero |= KnownZero2 & Mask;
> - KnownOne |= KnownOne2 & Mask;
> - } else {
> - // If the known zeros are in the left operand for a subtract,
> - // fall back to the minimum known zeros in both operands.
> - KnownZero |= APInt::getLowBitsSet(BitWidth,
> - std::min(LHSKnownZeroOut,
> - RHSKnownZeroOut));
> - }
> - } else if (RHSKnownZeroOut >= LHSKnownZeroOut) {
> - APInt Mask = APInt::getLowBitsSet(BitWidth, RHSKnownZeroOut);
> - KnownZero |= LHSKnownZero & Mask;
> - KnownOne |= LHSKnownOne & Mask;
> + // Calculate the sum/difference as if the unknown bits were all zeros, and
> + // another as if the unknown bits were all ones.
> + APInt OpWithZeros;
> + APInt OpWithOnes;
> + if (I->getOpcode() == Instruction::Add) {
> + OpWithZeros = LHSKnownOne + RHSKnownOne;
> + OpWithOnes = ~LHSKnownZero + ~RHSKnownZero;
> + } else {
> + OpWithZeros = LHSKnownOne - RHSKnownOne;
> + OpWithOnes = ~LHSKnownZero - ~RHSKnownZero;
> }
I think this is wrong for subtract. An example with 2-bit values,
using x for unknown bits:
  LHS = 1x // i.e. 10 or 11, two or three
  RHS = 0x // i.e. 00 or 01, zero or one
  // therefore LHS - RHS is one, two or three
  your OpWithZeros = 10
  your OpWithOnes = 10
  your CarryMask = 11
You would conclude that the result is 1x, which is wrong: 10 - 01 = 01,
whose top bit is clear.
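The 2-bit example can be replayed on plain integers. The following is a minimal sketch, not the LLVM code: it uses uint32_t in place of APInt, assumes the caller's Mask is all ones, and the names Known, patchSub and contradicts are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Known bits of a value, restricted to the low Width bits.
// (Hypothetical stand-in for the KnownZero/KnownOne APInt pairs.)
struct Known {
  uint32_t Zero; // bits known to be 0
  uint32_t One;  // bits known to be 1
};

// The subtract formula from the patch under review, with Mask assumed
// to be all ones.
Known patchSub(Known L, Known R, unsigned Width) {
  assert(Width > 0 && Width < 32);
  const uint32_t All = (1u << Width) - 1;
  uint32_t OpWithZeros = (L.One - R.One) & All;
  uint32_t OpWithOnes = ((~L.Zero & All) - (~R.Zero & All)) & All;
  uint32_t CarryMask = ~(OpWithZeros ^ OpWithOnes) & All;
  uint32_t KnownMask = (L.Zero | L.One) & (R.Zero | R.One) & CarryMask;
  return {~OpWithZeros & KnownMask, OpWithOnes & KnownMask};
}

// True if the concrete value V contradicts the claimed known bits.
bool contradicts(Known K, uint32_t V) {
  return (V & K.Zero) != 0 || (~V & K.One) != 0;
}
```

With LHS = 1x encoded as {Zero = 0, One = 2} and RHS = 0x as {Zero = 2, One = 0}, patchSub reports the top bit as known one, and contradicts flags the concrete case 2 - 1 = 1.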
A better attempt would be:
OpWithMinimalBorrows = ~LHSKnownZero - RHSKnownOne;
OpWithMaximalBorrows = LHSKnownOne - ~RHSKnownZero;
... but I'd have to think about this some more to be 100% convinced
that it is correct.
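One way to gain confidence in the min/max borrow formulation is to test it exhaustively at a small bit width. The sketch below assumes plain uint32_t values instead of APInt and an all-ones Mask. The helper names (Known, boundsSub, matches, countViolations) are hypothetical. A clean run at small widths is evidence, not a proof.

```cpp
#include <cassert>
#include <cstdint>

struct Known {
  uint32_t Zero; // bits known to be 0
  uint32_t One;  // bits known to be 1
};

// Fewest borrows: largest possible LHS minus smallest possible RHS.
// Most borrows: smallest possible LHS minus largest possible RHS.
Known boundsSub(Known L, Known R, unsigned Width) {
  assert(Width > 0 && Width < 32);
  const uint32_t All = (1u << Width) - 1;
  uint32_t MinBorrows = ((~L.Zero & All) - R.One) & All;
  uint32_t MaxBorrows = (L.One - (~R.Zero & All)) & All;
  uint32_t BorrowMask = ~(MinBorrows ^ MaxBorrows) & All;
  uint32_t KnownMask = (L.Zero | L.One) & (R.Zero | R.One) & BorrowMask;
  return {~MinBorrows & KnownMask, MinBorrows & KnownMask};
}

// True if the concrete value V is consistent with K.
bool matches(Known K, uint32_t V) {
  return (V & K.Zero) == 0 && (~V & K.One) == 0;
}

// Count concrete (A, B) pairs whose difference contradicts boundsSub,
// over every consistent known-bits pattern of the given width.
unsigned countViolations(unsigned Width) {
  const uint32_t All = (1u << Width) - 1;
  unsigned Bad = 0;
  for (uint32_t LZ = 0; LZ <= All; ++LZ)
    for (uint32_t LO = 0; LO <= All; ++LO) {
      if (LZ & LO) continue; // a bit cannot be known zero and known one
      for (uint32_t RZ = 0; RZ <= All; ++RZ)
        for (uint32_t RO = 0; RO <= All; ++RO) {
          if (RZ & RO) continue;
          Known L{LZ, LO}, R{RZ, RO};
          Known K = boundsSub(L, R, Width);
          for (uint32_t A = 0; A <= All; ++A) {
            if (!matches(L, A)) continue;
            for (uint32_t B = 0; B <= All; ++B)
              if (matches(R, B) && !matches(K, (A - B) & All))
                ++Bad;
          }
        }
    }
  return Bad;
}
```

The intuition for why this should hold: the borrow into bit i is 1 exactly when the low i bits of LHS are less than the low i bits of RHS, so it can only grow as LHS shrinks or RHS grows. If the two extreme cases agree on the borrow at a position, every case in between agrees too.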
(An alternative way of coding this is to write a general
SolveAdd(Value *LHS, Value *RHS, unsigned CarryIn), and then handle
ADD with SolveAdd(LHS, RHS, 0), and SUB with SolveAdd(LHS, ~RHS, 1),
since LHS - RHS == LHS + ~RHS + 1.)
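The SolveAdd idea might look roughly like this. This is a sketch under assumptions: it works on known-bits pairs rather than Value pointers, uses uint32_t instead of APInt, ignores Mask, and the names Known, solveAdd, knownAdd and knownSub are invented here. Inverting RHS just swaps its known-zero and known-one sets.

```cpp
#include <cassert>
#include <cstdint>

struct Known {
  uint32_t Zero; // bits known to be 0
  uint32_t One;  // bits known to be 1
};

// Known bits of L + R + CarryIn, where CarryIn is 0 or 1.
Known solveAdd(Known L, Known R, uint32_t CarryIn, unsigned Width) {
  assert(Width > 0 && Width < 32 && CarryIn <= 1);
  const uint32_t All = (1u << Width) - 1;
  // Sum with every unknown bit treated as 0, then as 1.
  uint32_t MinSum = (L.One + R.One + CarryIn) & All;
  uint32_t MaxSum = ((~L.Zero & All) + (~R.Zero & All) + CarryIn) & All;
  // Where the two sums agree, the carry into that bit does not depend
  // on the unknown bits.
  uint32_t CarryMask = ~(MinSum ^ MaxSum) & All;
  uint32_t KnownMask = (L.Zero | L.One) & (R.Zero | R.One) & CarryMask;
  return {~MinSum & KnownMask, MinSum & KnownMask};
}

Known knownAdd(Known L, Known R, unsigned Width) {
  return solveAdd(L, R, 0, Width);
}

// L - R == L + ~R + 1; the known bits of ~R are R's with Zero and One
// swapped.
Known knownSub(Known L, Known R, unsigned Width) {
  Known NotR{R.One, R.Zero};
  return solveAdd(L, NotR, 1, Width);
}
```

For example, knownAdd on 0100 and 00xx at width 4 yields 01xx, and knownSub on the 1x minus 0x case yields no known bits.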
>
> + // At a bit position where OpWithZeros and OpWithOne agree, the carry
> + // value is the same regardless of the unknown bits.
> + APInt CarryMask = ~(OpWithZeros ^ OpWithOnes);
> +
> + // We can only know the result of the subtraction when we know the values
> + // for the left, right and carry inputs.
> + APInt KnownMask = (LHSKnownZero | LHSKnownOne) &
> + (RHSKnownZero | RHSKnownOne) &
> + CarryMask &
> + Mask;
> +
> + // At this stage, for every bit position where KnownMask is true, exactly
> + // one of either OpWithZeros or OpWithOnes is set.
> + KnownZero = ~OpWithZeros & KnownMask;
> + KnownOne = OpWithOnes & KnownMask;
The comment is wrong. For every bit position where KnownMask is 1,
OpWithZeros == OpWithOnes, because KnownMask includes CarryMask, which
is ~(OpWithZeros ^ OpWithOnes).
> +
> // Are we still trying to solve for the sign bit?
> - if (Mask.isNegative() && !KnownZero.isNegative() && !KnownOne.isNegative()){
> + if (Mask.isNegative() && !KnownMask.isNegative()) {
> OverflowingBinaryOperator *OBO = cast<OverflowingBinaryOperator>(I);
> if (OBO->hasNoSignedWrap()) {
> if (I->getOpcode() == Instruction::Add) {
> // Adding two positive numbers can't wrap into negative
> - if (LHSKnownZero.isNegative() && KnownZero2.isNegative())
> + if (LHSKnownZero.isNegative() && RHSKnownZero.isNegative())
> KnownZero |= APInt::getSignBit(BitWidth);
> // and adding two negative numbers can't wrap into positive.
> - else if (LHSKnownOne.isNegative() && KnownOne2.isNegative())
> + else if (LHSKnownOne.isNegative() && RHSKnownOne.isNegative())
> KnownOne |= APInt::getSignBit(BitWidth);
> } else {
> // Subtracting a negative number from a positive one can't wrap
> - if (LHSKnownZero.isNegative() && KnownOne2.isNegative())
> + if (LHSKnownZero.isNegative() && RHSKnownOne.isNegative())
> KnownZero |= APInt::getSignBit(BitWidth);
> // neither can subtracting a positive number from a negative one.
> - else if (LHSKnownOne.isNegative() && KnownZero2.isNegative())
> + else if (LHSKnownOne.isNegative() && RHSKnownZero.isNegative())
> KnownOne |= APInt::getSignBit(BitWidth);
> }
> }
The rest looks good to me.
Jay.