[llvm] KnownBits: generalize high-bits of mul to overflows (PR #114211)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 08:10:33 PDT 2024


================
@@ -796,19 +796,76 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   assert((!NoUndefSelfMultiply || LHS == RHS) &&
          "Self multiplication knownbits mismatch");
 
-  // Compute the high known-0 bits by multiplying the unsigned max of each side.
-  // Conservatively, M active bits * N active bits results in M + N bits in the
-  // result. But if we know a value is a power-of-2 for example, then this
-  // computes one more leading zero.
-  // TODO: This could be generalized to number of sign bits (negative numbers).
-  APInt UMaxLHS = LHS.getMaxValue();
-  APInt UMaxRHS = RHS.getMaxValue();
-
-  // For leading zeros in the result to be valid, the unsigned max product must
-  // fit in the bitwidth (it must not overflow).
+  // Compute the high known-0 or known-1 bits by multiplying the min and max of
+  // each side.
+  APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+        MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+        MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+  // If MaxProduct doesn't overflow, it implies that MinProduct also won't
+  // overflow. However, if MaxProduct overflows, there is no guarantee on the
+  // MinProduct overflowing.
   bool HasOverflow;
-  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
-  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+  APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+        MinProduct = MinLHS * MinRHS;
+
+  if (LHS.isNegative() != RHS.isNegative()) {
+    // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+    // negated to turn them into the corresponding signed-multiplication
+    // wrapped values.
+    MinProduct.negate();
+    MaxProduct.negate();
+
+    // MinProduct < MaxProduct is now MaxProduct < MinProduct.
+    std::swap(MinProduct, MaxProduct);
+  }
+
+  // Unless both MinProduct and MaxProduct are the same sign, there won't be any
+  // leading zeros or ones in the result.
+  unsigned LeadZ = 0, LeadO = 0;
+  if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+    APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+          RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+    // A product of M active bits * N active bits results in M + N bits in the
+    // result. If either of the operands is a power of two, the result has one
+    // less active bit.
+    auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+      if (A.isZero() || B.isZero())
+        return 0;
+      return A.getActiveBits() + B.getActiveBits() -
+             (A.isPowerOf2() || B.isPowerOf2());
+    };
+
+    // We want to compute the number of active bits in the difference between
+    // the non-wrapped max product and non-wrapped min product, but we want to
+    // avoid computing the non-wrapped max/min product.
+    unsigned ActiveBitsInDiff;
+    if (MinLHS.isZero() && MinRHS.isZero())
+      ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+    else
+      ActiveBitsInDiff =
+          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+          ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
----------------
artagnon wrote:

There is a bug here. Marking as draft until I work it out.

https://github.com/llvm/llvm-project/pull/114211
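The pre-patch rule removed by the `-` lines of the hunk is: multiply the unsigned maxima, and trust the product's leading zeros only if that product does not overflow. It can be sketched outside LLVM as follows. This is a minimal illustration. It assumes a hypothetical 8-bit stand-in `KnownBits8` in place of `llvm::KnownBits`/`APInt`, and a widened 32-bit multiply in place of `umul_ov`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit stand-in for llvm::KnownBits: bits set in Zero are known
// to be 0, bits set in One are known to be 1, and no bit is set in both.
struct KnownBits8 {
  uint8_t Zero, One;
  // Largest unsigned value consistent with the known bits (unknown bits = 1).
  uint8_t getMaxValue() const { return static_cast<uint8_t>(~Zero); }
};

// Number of leading zero bits in an 8-bit value (8 for zero).
unsigned countLeadingZeros8(uint8_t V) {
  unsigned N = 0;
  for (int Bit = 7; Bit >= 0 && !((V >> Bit) & 1); --Bit)
    ++N;
  return N;
}

// Pre-patch rule: the product's leading known zeros are the leading zeros of
// UMaxLHS * UMaxRHS, provided that product fits in 8 bits.
unsigned mulLeadingZeros(const KnownBits8 &LHS, const KnownBits8 &RHS) {
  assert((LHS.Zero & LHS.One) == 0 && (RHS.Zero & RHS.One) == 0);
  // Widen to 32 bits so the overflow check is exact (stands in for umul_ov).
  uint32_t UMaxResult = uint32_t(LHS.getMaxValue()) * RHS.getMaxValue();
  if (UMaxResult > 0xFF)
    return 0;
  return countLeadingZeros8(static_cast<uint8_t>(UMaxResult));
}
```

Take LHS known to be exactly 4 (`Zero = 0xFB, One = 0x04`) and RHS known to be at most 7 (`Zero = 0xF8`). The maximal product is 28 = 0b00011100, which has 3 leading zeros. The conservative "3 + 3 active bits" estimate gives only 2. This is the extra zero the removed comment attributes to a power-of-2 operand.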
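The new code follows three steps. It bounds each side by its extreme magnitudes, flips the bounds when the operand signs differ, and trusts high bits only when the two bounds agree in sign. That idea can be cross-checked against a direct range computation. The sketch below is an illustration under stated assumptions, not the patch's algorithm:

- It uses a hypothetical `KB` struct of width W <= 16 instead of `APInt`.
- It computes the exact signed product range from 64-bit corner products, rather than avoiding wide arithmetic.
- It returns the whole common prefix of the two bounds, which includes the leading zeros or ones the patch is after.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical W-bit stand-in for llvm::KnownBits (1 <= W <= 16). Zero and One
// hold the known-0 and known-1 masks in their low W bits.
struct KB {
  unsigned W;
  uint32_t Zero, One;

  uint32_t mask() const { return (1u << W) - 1; }
  uint32_t signBit() const { return 1u << (W - 1); }
  // Reinterpret a W-bit pattern as a two's-complement value.
  int64_t toSigned(uint32_t V) const {
    return (V & signBit()) ? int64_t(V) - (int64_t(1) << W) : int64_t(V);
  }
  // Smallest signed value consistent with the known bits: unknown bits are 0,
  // except an unknown sign bit, which is 1.
  int64_t smin() const {
    uint32_t V = One;
    if (!(Zero & signBit()))
      V |= signBit();
    return toSigned(V);
  }
  // Largest signed value: unknown bits are 1, except an unknown sign bit.
  int64_t smax() const {
    uint32_t V = ~Zero & mask();
    if (!(One & signBit()))
      V &= ~signBit();
    return toSigned(V);
  }
};

// Known high bits of the W-bit product L * R, read off the exact signed
// product range [Lo, Hi]; all lower bits are left unknown.
KB mulHighBits(const KB &L, const KB &R) {
  assert(L.W == R.W && L.W >= 1 && L.W <= 16);
  const unsigned W = L.W;
  KB Res{W, 0, 0};
  // Over a box of intervals, x * y attains its extremes at the corners.
  const int64_t C[4] = {L.smin() * R.smin(), L.smin() * R.smax(),
                        L.smax() * R.smin(), L.smax() * R.smax()};
  const int64_t Lo = *std::min_element(C, C + 4);
  const int64_t Hi = *std::max_element(C, C + 4);
  // Outside the W-bit signed range the product can wrap: claim nothing.
  const int64_t SMin = -(int64_t(1) << (W - 1)), SMax = -SMin - 1;
  if (Lo < SMin || Hi > SMax)
    return Res;
  // Within one sign half, W-bit patterns are ordered like the values they
  // encode, so every product in [Lo, Hi] shares the leading bits on which Lo
  // and Hi agree. If Lo and Hi differ in sign, that prefix is empty.
  const uint32_t LoBits = uint32_t(Lo) & L.mask();
  const uint32_t HiBits = uint32_t(Hi) & L.mask();
  const uint32_t Diff = LoBits ^ HiBits;
  uint32_t Prefix = 0;
  for (int Bit = int(W) - 1; Bit >= 0 && !((Diff >> Bit) & 1); --Bit)
    Prefix |= 1u << Bit;
  Res.Zero = Prefix & ~LoBits;
  Res.One = Prefix & LoBits;
  return Res;
}
```

W is a parameter, so the result can be compared exhaustively against every concrete product at a small width such as 4, where each operand has 3^4 = 81 possible KnownBits values. That kind of brute-force comparison is a cheap way to hunt for unsound cases like the one the comment mentions.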
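The `ProdActiveBits` bound in the hunk rests on a simple fact. If 2^(M-1) <= A < 2^M and 2^(N-1) <= B < 2^N, then A*B < 2^(M+N). If A = 2^(M-1) exactly, then A*B < 2^(M+N-1). Below is a standalone brute-force check of that bound over all 8-bit operand pairs. It assumes plain `uint32_t` values in place of `APInt`, and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Number of significant bits in V (0 for V == 0), like APInt::getActiveBits.
unsigned activeBits(uint32_t V) {
  unsigned N = 0;
  for (; V != 0; V >>= 1)
    ++N;
  return N;
}

bool isPowerOf2(uint32_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Mirrors the ProdActiveBits lambda from the hunk.
unsigned prodActiveBits(uint32_t A, uint32_t B) {
  if (A == 0 || B == 0)
    return 0;
  return activeBits(A) + activeBits(B) - (isPowerOf2(A) || isPowerOf2(B));
}

// Returns true if activeBits(A * B) <= prodActiveBits(A, B) for all A, B < 256.
bool checkProdActiveBitsBound() {
  for (uint32_t A = 0; A < 256; ++A)
    for (uint32_t B = 0; B < 256; ++B)
      if (activeBits(A * B) > prodActiveBits(A, B))
        return false;
  return true;
}
```

The bound only concerns the magnitude of a product of two given values. Whether the unknown-bit masks passed to it in the hunk bound the intended quantity is a separate question.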
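For reference, the comment near the end of the hunk wants the non-wrapped MaxLHS*MaxRHS minus MinLHS*MinRHS. Write U_L = MaxLHS - MinLHS and U_R = MaxRHS - MinRHS. For a nonnegative operand, U_L is the value of the unknown-bit mask `LHSUnknown`, since `getMaxValue()` sets every unknown bit and `getMinValue()` clears them. For a negative operand the same holds for the absolute values used in the hunk. The difference then expands as:

```latex
\begin{aligned}
\mathit{MaxLHS}\cdot\mathit{MaxRHS} - \mathit{MinLHS}\cdot\mathit{MinRHS}
  &= (\mathit{MinLHS} + U_L)(\mathit{MinRHS} + U_R) - \mathit{MinLHS}\cdot\mathit{MinRHS} \\
  &= \mathit{MinLHS}\cdot U_R + \mathit{MinRHS}\cdot U_L + U_L\cdot U_R
\end{aligned}
```

The identity has three terms. When both minima are nonzero, the `else` branch of the hunk sums `ProdActiveBits` for only the first two, so comparing the two may be worthwhile. This is an observation about the quoted code, not a confirmed diagnosis of the bug the comment refers to.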

