[llvm] KnownBits: generalize high-bits of mul to overflows (PR #114211)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 04:37:27 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-support
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case when overflow occurs by relying on the min-product in addition to the max-product. Note that the overflow case cannot possibly be optimal unless we also look at the bits between the min-product and the max-product.
-- 8< --
Based on #113051.
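To illustrate the underlying idea (a standalone sketch, not the patch itself): over a contiguous, non-wrapping range of products, every value shares the leading bits on which the minimum and maximum product agree, so those bits become known bits of the result. A minimal sketch in plain 8-bit arithmetic, using a hypothetical helper `knownHighBits`:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper, not part of the patch: for a non-wrapping range
// [Min, Max], every value shares the leading bits on which Min and Max
// agree, so those bits are "known". Returns them as a mask plus values.
static void knownHighBits(uint8_t Min, uint8_t Max, uint8_t &KnownMask,
                          uint8_t &KnownVal) {
  uint8_t Diff = Min ^ Max;
  if (Diff == 0) { // Min == Max: the whole value is known.
    KnownMask = 0xFF;
    KnownVal = Min;
    return;
  }
  // Find the highest bit where Min and Max differ; everything above it
  // agrees and is therefore known for the entire range.
  unsigned HighestDiff = 0;
  while (Diff >>= 1)
    ++HighestDiff;
  KnownMask = (uint8_t)~((2u << HighestDiff) - 1);
  KnownVal = Min & KnownMask;
}

int main() {
  // x in [3, 4] times y = 2: products lie in [6, 8], so the top four bits
  // of the 8-bit product are known zero (prints mask=0xf0 val=0x00).
  uint8_t Mask, Val;
  knownHighBits(3 * 2, 4 * 2, Mask, Val);
  printf("mask=0x%02x val=0x%02x\n", Mask, Val);
}
```

Once the max product wraps, the wrapped products no longer form a contiguous range bounded by the wrapped min and max, which is why the overflow case can only recover leading bits under the restricted conditions handled by the patch, and why full optimality would require examining the products in between.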
---
Full diff: https://github.com/llvm/llvm-project/pull/114211.diff
3 Files Affected:
- (modified) llvm/lib/Support/KnownBits.cpp (+88-13)
- (added) llvm/test/Analysis/ValueTracking/knownbits-mul.ll (+143)
- (modified) llvm/unittests/Support/KnownBitsTest.cpp (+148-1)
``````````diff
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..c2d7c776725088 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,93 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
assert((!NoUndefSelfMultiply || LHS == RHS) &&
"Self multiplication knownbits mismatch");
- // Compute the high known-0 bits by multiplying the unsigned max of each side.
- // Conservatively, M active bits * N active bits results in M + N bits in the
- // result. But if we know a value is a power-of-2 for example, then this
- // computes one more leading zero.
- // TODO: This could be generalized to number of sign bits (negative numbers).
- APInt UMaxLHS = LHS.getMaxValue();
- APInt UMaxRHS = RHS.getMaxValue();
-
- // For leading zeros in the result to be valid, the unsigned max product must
- // fit in the bitwidth (it must not overflow).
+ // Compute the high known-0 or known-1 bits by multiplying the min and max of
+ // each side.
+ APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+ MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+ MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+ MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+  // If MaxProduct doesn't overflow, MinProduct can't overflow either, since
+  // MinLHS <= MaxLHS and MinRHS <= MaxRHS. However, if MaxProduct overflows,
+  // MinProduct may or may not overflow.
bool HasOverflow;
- APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
- unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+ APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+ MinProduct = MinLHS * MinRHS;
+
+ bool OpsSignMatch = LHS.isNegative() == RHS.isNegative();
+ if (!OpsSignMatch) {
+ // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+ // negated to turn them into the corresponding signed-multiplication
+ // wrapped values.
+ MinProduct.negate();
+ MaxProduct.negate();
+ }
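+  // For example, if LHS is known to be in [-4, -1] and RHS in [1, 2], the
+  // products of the absolute values are MinProduct = 1 and MaxProduct = 8;
+  // negating them yields -1 and -8, the signed extrema of the real product
+  // range [-8, -1].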
+
+  // Unless MinProduct and MaxProduct have the same sign, there won't be any
+  // leading zeros or ones in the result.
+ unsigned LeadZ = 0, LeadO = 0;
+ if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+ APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+ RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+ // A product of M active bits * N active bits results in M + N bits in the
+ // result. If either of the operands is a power of two, the result has one
+ // less active bit.
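+  // For example, 0b111 * 0b11 = 21 occupies 3 + 2 = 5 bits, while the
+  // power-of-2 case 0b100 * 0b11 = 12 occupies 3 + 2 - 1 = 4 bits.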
+ auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+ if (A.isZero() || B.isZero())
+ return 0;
+ return A.getActiveBits() + B.getActiveBits() -
+ (A.isPowerOf2() || B.isPowerOf2());
+ };
+
+  // We want to compute the number of active bits in the difference between
+  // the non-wrapped max product and the non-wrapped min product, while
+  // avoiding computing the non-wrapped products themselves.
+ unsigned ActiveBitsInDiff;
+ if (MinLHS.isZero() && MinRHS.isZero())
+ ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+ else
+ ActiveBitsInDiff =
+ ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+ ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+  // We uniformly handle two cases: when there is no max-overflow, the high
+  // zeros and ones are computed optimally; when there is max-overflow but
+  // the result shifts by at most BitWidth, they are computed non-optimally.
+ if (!HasOverflow || ActiveBitsInDiff <= BitWidth) {
+ // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+ // and B is zero.
+ auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+ return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+ };
+
+ // If we're shifting by BitWidth, MaxProduct and MinProduct are swapped.
+ bool MinMaxSwap = ActiveBitsInDiff == BitWidth;
+ if (MinMaxSwap)
+ std::swap(MaxProduct, MinProduct);
+
+ if (OpsSignMatch != MinMaxSwap) {
+      // Normally, this branch handles the case where the signs of LHS and
+      // RHS match, and the else branch the case where they mismatch. However,
+      // if min and max were swapped, the two cases are inverted.
+ if (UgtCheckCorner(MaxProduct, MinProduct)) {
+ // Normally, when the signs of LHS and RHS match, we can safely set
+ // leading zeros of the result. However, if both MaxProduct and
+ // MinProduct are negative, we can also set the leading ones.
+ LeadZ = MaxProduct.countLeadingZeros();
+ LeadO = (MaxProduct & MinProduct).countLeadingOnes();
+ }
+ } else if (UgtCheckCorner(MinProduct, MaxProduct)) {
+ // Normally, when the signs of LHS and RHS mismatch, we can safely set
+ // leading ones of the result. However, if both MaxProduct and
+ // MinProduct are non-negative, we can also set the leading zeros.
+ LeadO = MaxProduct.countLeadingOnes();
+ LeadZ = (MaxProduct | MinProduct).countLeadingZeros();
+ }
+ }
+ }
// The result of the bottom bits of an integer multiply can be
// inferred by looking at the bottom bits of both operands and
@@ -873,8 +947,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
KnownBits Res(BitWidth);
Res.Zero.setHighBits(LeadZ);
+ Res.One.setHighBits(LeadO);
Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
- Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+ Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
// If we're self-multiplying then bit[1] is guaranteed to be zero.
if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..37526c67f0d9e1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 2
+ ret i8 %r
+}
+
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT: [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %x.notsmin = or i8 %x, 3
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x.notsmin, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
+
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 -16
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %y.nonzero = or i8 %y, 1
+ %mul = mul i8 %x, %y.nonzero
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: ret i8 [[MUL]]
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 8
+ ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
}
}
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
for (unsigned Bits : {1, 4}) {
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
}
}
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+ for (unsigned Bits : {1, 4}) {
+ ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+ ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+ KnownBits Computed = KnownBits::mul(Known1, Known2);
+ KnownBits Exact(Bits), WideExact(2 * Bits);
+ Exact.Zero.setAllBits();
+ Exact.One.setAllBits();
+
+ bool HasOverflow;
+ ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+ ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+ // The final value of HasOverflow corresponds to the multiplication
+ // in the last iteration, which is the max product.
+ APInt Res = N1.umul_ov(N2, HasOverflow);
+ Exact.One &= Res;
+ Exact.Zero &= ~Res;
+ });
+ });
+
+ if (!Exact.hasConflict() && !HasOverflow) {
+ // Check that leading zeros and leading ones are optimal in the
+ // result, provided there is no overflow.
+ APInt ZerosMask =
+ APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+ OnesMask =
+ APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+ KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+ KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+ ExactZeros.Zero.setAllBits();
+ ExactZeros.One.setAllBits();
+ ExactOnes.Zero.setAllBits();
+ ExactOnes.One.setAllBits();
+
+ ExactZeros.Zero = Exact.Zero & ZerosMask;
+ ExactZeros.One = Exact.One & ZerosMask;
+ ComputedZeros.Zero = Computed.Zero & ZerosMask;
+ ComputedZeros.One = Computed.One & ZerosMask;
+ EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+ {Known1, Known2},
+ /*CheckOptimality=*/true));
+
+ ExactOnes.Zero = Exact.Zero & OnesMask;
+ ExactOnes.One = Exact.One & OnesMask;
+ ComputedOnes.Zero = Computed.Zero & OnesMask;
+ ComputedOnes.One = Computed.One & OnesMask;
+ EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+ {Known1, Known2},
+ /*CheckOptimality=*/true));
+ }
+ });
+ });
+ }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+ unsigned Bits = 4;
+ using KnownUnknownPair = std::pair<int, int>;
+ SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+ {{2, 0}, {7, -1}}, // 001?, 0111
+ {{2, -1}, {10, 0}}, // 0010, 101?
+ {{9, 2}, {9, 1}}, // 1?01, 10?1
+ {{5, 1}, {3, 2}}}; // 01?1, 0?11
+ for (auto [P1, P2] : TestPairs) {
+ KnownBits Known1(Bits), Known2(Bits);
+ auto [K1, U1] = P1;
+ auto [K2, U2] = P2;
+ Known1 = KnownBits::makeConstant(APInt(Bits, K1));
+ Known2 = KnownBits::makeConstant(APInt(Bits, K2));
+ if (U1 > -1) {
+ Known1.Zero.setBitVal(U1, 0);
+ Known1.One.setBitVal(U1, 0);
+ }
+ if (U2 > -1) {
+ Known2.Zero.setBitVal(U2, 0);
+ Known2.One.setBitVal(U2, 0);
+ }
+ KnownBits Computed = KnownBits::mul(Known1, Known2);
+ KnownBits Exact(Bits);
+ Exact.Zero.setAllBits();
+ Exact.One.setAllBits();
+
+ ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+ ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+ APInt Res = N1 * N2;
+ Exact.One &= Res;
+ Exact.Zero &= ~Res;
+ });
+ });
+
+    // Check that the leading zeros and ones are optimal for these specific
+    // overflowing examples; the computation is certainly sub-optimal on
+    // other examples.
+ APInt ZerosMask =
+ APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+ OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+ KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+ KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+ ExactZeros.Zero.setAllBits();
+ ExactZeros.One.setAllBits();
+ ExactOnes.Zero.setAllBits();
+ ExactOnes.One.setAllBits();
+
+ ExactZeros.Zero = Exact.Zero & ZerosMask;
+ ExactZeros.One = Exact.One & ZerosMask;
+ ComputedZeros.Zero = Computed.Zero & ZerosMask;
+ ComputedZeros.One = Computed.One & ZerosMask;
+ EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+ /*CheckOptimality=*/true));
+
+ ExactOnes.Zero = Exact.Zero & OnesMask;
+ ExactOnes.One = Exact.One & OnesMask;
+ ComputedOnes.Zero = Computed.Zero & OnesMask;
+ ComputedOnes.One = Computed.One & OnesMask;
+ EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+ /*CheckOptimality=*/true));
+ }
+}
+
+TEST(KnownBitsTest, MulStress) {
+ // Stress test KnownBits::mul on 5 and 6 bits, checking that the result is
+ // correct, even if not optimal.
+ for (unsigned Bits : {5, 6}) {
+ ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+ ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+ KnownBits Computed = KnownBits::mul(Known1, Known2);
+ KnownBits Exact(Bits);
+ Exact.Zero.setAllBits();
+ Exact.One.setAllBits();
+
+ ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+ ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+ APInt Res = N1 * N2;
+ Exact.One &= Res;
+ Exact.Zero &= ~Res;
+ });
+ });
+
+ if (!Exact.hasConflict()) {
+ EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+ /*CheckOptimality=*/false));
+ }
+ });
+ });
+ }
+}
} // end anonymous namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/114211