[llvm] KnownBits: generalize high-bits of mul to overflows (PR #114211)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 04:37:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-support

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case where overflow occurs by relying on the min-product in addition to the max-product. Note that the overflow case cannot possibly be optimal unless we also look at the bits in between min-product and max-product.

-- 8< --
Based on #<!-- -->113051.

---
Full diff: https://github.com/llvm/llvm-project/pull/114211.diff


3 Files Affected:

- (modified) llvm/lib/Support/KnownBits.cpp (+88-13) 
- (added) llvm/test/Analysis/ValueTracking/knownbits-mul.ll (+143) 
- (modified) llvm/unittests/Support/KnownBitsTest.cpp (+148-1) 


``````````diff
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..c2d7c776725088 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,93 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   assert((!NoUndefSelfMultiply || LHS == RHS) &&
          "Self multiplication knownbits mismatch");
 
-  // Compute the high known-0 bits by multiplying the unsigned max of each side.
-  // Conservatively, M active bits * N active bits results in M + N bits in the
-  // result. But if we know a value is a power-of-2 for example, then this
-  // computes one more leading zero.
-  // TODO: This could be generalized to number of sign bits (negative numbers).
-  APInt UMaxLHS = LHS.getMaxValue();
-  APInt UMaxRHS = RHS.getMaxValue();
-
-  // For leading zeros in the result to be valid, the unsigned max product must
-  // fit in the bitwidth (it must not overflow).
+  // Compute the high known-0 or known-1 bits by multiplying the min and max of
+  // each side.
+  APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+        MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+        MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+  // If MaxProduct doesn't overflow, it implies that MinProduct also won't
+  // overflow. However, if MaxProduct overflows, there is no guarantee that
+  // MinProduct overflows as well.
   bool HasOverflow;
-  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
-  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+  APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+        MinProduct = MinLHS * MinRHS;
+
+  bool OpsSignMatch = LHS.isNegative() == RHS.isNegative();
+  if (!OpsSignMatch) {
+    // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+    // negated to turn them into the corresponding signed-multiplication
+    // wrapped values.
+    MinProduct.negate();
+    MaxProduct.negate();
+  }
+
+  // Unless both MinProduct and MaxProduct are the same sign, there won't be any
+  // leading zeros or ones in the result.
+  unsigned LeadZ = 0, LeadO = 0;
+  if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+    APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+          RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+    // A product of M active bits * N active bits results in M + N bits in the
+    // result. If either of the operands is a power of two, the result has one
+    // less active bit.
+    auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+      if (A.isZero() || B.isZero())
+        return 0;
+      return A.getActiveBits() + B.getActiveBits() -
+             (A.isPowerOf2() || B.isPowerOf2());
+    };
+
+    // We want to compute the number of active bits in the difference between
+    // the non-wrapped max product and non-wrapped min product, but we want to
+    // avoid computing the non-wrapped max/min product.
+    unsigned ActiveBitsInDiff;
+    if (MinLHS.isZero() && MinRHS.isZero())
+      ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+    else
+      ActiveBitsInDiff =
+          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+          ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+    // We uniformly handle the case where there is no max-overflow, in which
+    // case the high zeros and ones are computed optimally, and where there is,
+    // but the result shifts at most by BitWidth, in which case the high zeros
+    // and ones are not computed optimally.
+    if (!HasOverflow || ActiveBitsInDiff <= BitWidth) {
+      // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+      // and B is zero.
+      auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+        return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+      };
+
+      // If we're shifting by BitWidth, MaxProduct and MinProduct are swapped.
+      bool MinMaxSwap = ActiveBitsInDiff == BitWidth;
+      if (MinMaxSwap)
+        std::swap(MaxProduct, MinProduct);
+
+      if (OpsSignMatch != MinMaxSwap) {
+        // Normally, this is the case for when the signs of LHS and RHS match,
+        // and the else branch is for when the signs mismatch. However, if min
+        // and max were swapped, we need to invert these cases.
+        if (UgtCheckCorner(MaxProduct, MinProduct)) {
+          // Normally, when the signs of LHS and RHS match, we can safely set
+          // leading zeros of the result. However, if both MaxProduct and
+          // MinProduct are negative, we can also set the leading ones.
+          LeadZ = MaxProduct.countLeadingZeros();
+          LeadO = (MaxProduct & MinProduct).countLeadingOnes();
+        }
+      } else if (UgtCheckCorner(MinProduct, MaxProduct)) {
+        // Normally, when the signs of LHS and RHS mismatch, we can safely set
+        // leading ones of the result. However, if both MaxProduct and
+        // MinProduct are non-negative, we can also set the leading zeros.
+        LeadO = MaxProduct.countLeadingOnes();
+        LeadZ = (MaxProduct | MinProduct).countLeadingZeros();
+      }
+    }
+  }
 
   // The result of the bottom bits of an integer multiply can be
   // inferred by looking at the bottom bits of both operands and
@@ -873,8 +947,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
 
   KnownBits Res(BitWidth);
   Res.Zero.setHighBits(LeadZ);
+  Res.One.setHighBits(LeadO);
   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
-  Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+  Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // If we're self-multiplying then bit[1] is guaranteed to be zero.
   if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..37526c67f0d9e1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 2
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %x.notsmin = or i8 %x, 3
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x.notsmin, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 -16
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %y.nonzero = or i8 %y, 1
+  %mul = mul i8 %x, %y.nonzero
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 8
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
   }
 }
 
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
   for (unsigned Bits : {1, 4}) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
       ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }
 
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits), WideExact(2 * Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        bool HasOverflow;
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            // The final value of HasOverflow corresponds to the multiplication
+            // in the last iteration, which is the max product.
+            APInt Res = N1.umul_ov(N2, HasOverflow);
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict() && !HasOverflow) {
+          // Check that leading zeros and leading ones are optimal in the
+          // result, provided there is no overflow.
+          APInt ZerosMask =
+                    APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+                OnesMask =
+                    APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+          KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+          KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+          ExactZeros.Zero.setAllBits();
+          ExactZeros.One.setAllBits();
+          ExactOnes.Zero.setAllBits();
+          ExactOnes.One.setAllBits();
+
+          ExactZeros.Zero = Exact.Zero & ZerosMask;
+          ExactZeros.One = Exact.One & ZerosMask;
+          ComputedZeros.Zero = Computed.Zero & ZerosMask;
+          ComputedZeros.One = Computed.One & ZerosMask;
+          EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+
+          ExactOnes.Zero = Exact.Zero & OnesMask;
+          ExactOnes.One = Exact.One & OnesMask;
+          ComputedOnes.Zero = Computed.Zero & OnesMask;
+          ComputedOnes.One = Computed.One & OnesMask;
+          EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+        }
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  unsigned Bits = 4;
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+  for (auto [P1, P2] : TestPairs) {
+    KnownBits Known1(Bits), Known2(Bits);
+    auto [K1, U1] = P1;
+    auto [K2, U2] = P2;
+    Known1 = KnownBits::makeConstant(APInt(Bits, K1));
+    Known2 = KnownBits::makeConstant(APInt(Bits, K2));
+    if (U1 > -1) {
+      Known1.Zero.setBitVal(U1, 0);
+      Known1.One.setBitVal(U1, 0);
+    }
+    if (U2 > -1) {
+      Known2.Zero.setBitVal(U2, 0);
+      Known2.One.setBitVal(U2, 0);
+    }
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Res = N1 * N2;
+        Exact.One &= Res;
+        Exact.Zero &= ~Res;
+      });
+    });
+
+    // Check that the leading zeros or ones are optimal for the given examples,
+    // which overflow. It is certainly sub-optimal on other examples.
+    APInt ZerosMask =
+              APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+          OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+    KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+    ExactZeros.Zero.setAllBits();
+    ExactZeros.One.setAllBits();
+    ExactOnes.Zero.setAllBits();
+    ExactOnes.One.setAllBits();
+
+    ExactZeros.Zero = Exact.Zero & ZerosMask;
+    ExactZeros.One = Exact.One & ZerosMask;
+    ComputedZeros.Zero = Computed.Zero & ZerosMask;
+    ComputedZeros.One = Computed.One & ZerosMask;
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    ExactOnes.Zero = Exact.Zero & OnesMask;
+    ExactOnes.One = Exact.One & OnesMask;
+    ComputedOnes.Zero = Computed.Zero & OnesMask;
+    ComputedOnes.One = Computed.One & OnesMask;
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
+TEST(KnownBitsTest, MulStress) {
+  // Stress test KnownBits::mul on 5 and 6 bits, checking that the result is
+  // correct, even if not optimal.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Res = N1 * N2;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict()) {
+          EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                  /*CheckOptimality=*/false));
+        }
+      });
+    });
+  }
+}
 } // end anonymous namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/114211


More information about the llvm-commits mailing list