[llvm] KnownBits: generalize high-bits of mul to overflows (PR #114211)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 31 15:47:50 PDT 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/114211

>From 9459ba972c4caf45679b50cef9f0735099b73255 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Sat, 19 Oct 2024 13:34:23 +0100
Subject: [PATCH 1/6] ValueTracking/test: cover known bits of mul

---
 .../Analysis/ValueTracking/knownbits-mul.ll   | 152 ++++++++++++++++++
 1 file changed, 152 insertions(+)
 create mode 100644 llvm/test/Analysis/ValueTracking/knownbits-mul.ll

diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..79df43c99744e6
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,152 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 2
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %x.notsmin = or i8 %x, 3
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x.notsmin, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[Y_NONZERO:%.*]] = or disjoint i8 [[Y]], 1
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y_NONZERO]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %y.nonzero = or i8 %y, 1
+  %mul = mul i8 %x, %y.nonzero
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 124
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 126
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 112
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 8
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}

>From b6b0cfdf0fafddd82638675e80a461e02c3c5198 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 18 Oct 2024 23:25:07 +0100
Subject: [PATCH 2/6] KnownBits: refine high-bits of mul in signed case

KnownBits::mul does not account for inputs that are known to be
negative. Fix this by refining the known leading zeros when both inputs
are known negative, and by setting known leading ones when exactly one
input is known negative. The strategy is to keep using umul_ov on the
magnitudes of the inputs and, when the product is known to be negative,
to set the known leading ones from its negation. Leading ones are only
set when both inputs are known non-zero, because a possibly-zero product
has no known leading ones.
---
 llvm/lib/Support/KnownBits.cpp                | 32 ++++++++++++-------
 .../Analysis/ValueTracking/knownbits-mul.ll   | 13 ++------
 2 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..b63945f202a34d 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,26 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   assert((!NoUndefSelfMultiply || LHS == RHS) &&
          "Self multiplication knownbits mismatch");
 
-  // Compute the high known-0 bits by multiplying the unsigned max of each side.
-  // Conservatively, M active bits * N active bits results in M + N bits in the
-  // result. But if we know a value is a power-of-2 for example, then this
-  // computes one more leading zero.
-  // TODO: This could be generalized to number of sign bits (negative numbers).
-  APInt UMaxLHS = LHS.getMaxValue();
-  APInt UMaxRHS = RHS.getMaxValue();
-
-  // For leading zeros in the result to be valid, the unsigned max product must
+  // Compute the high known-0 or known-1 bits by multiplying the max of each
+  // side. Conservatively, M active bits * N active bits results in M + N bits
+  // in the result. But if we know a value is a power-of-2 for example, then
+  // this computes one more leading zero or one.
+  APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue();
+
+  // For leading zeros or ones in the result to be valid, the max product must
   // fit in the bitwidth (it must not overflow).
   bool HasOverflow;
-  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
-  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+  APInt Result = MaxLHS.umul_ov(MaxRHS, HasOverflow);
+  bool NegResult = LHS.isNegative() ^ RHS.isNegative();
+  unsigned LeadZ = 0, LeadO = 0;
+  if (!HasOverflow) {
+    // Do not set leading ones unless the result is known to be non-zero.
+    if (NegResult && LHS.isNonZero() && RHS.isNonZero())
+      LeadO = (-Result).countLeadingOnes();
+    else if (!NegResult)
+      LeadZ = Result.countLeadingZeros();
+  }
 
   // The result of the bottom bits of an integer multiply can be
   // inferred by looking at the bottom bits of both operands and
@@ -873,8 +880,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
 
   KnownBits Res(BitWidth);
   Res.Zero.setHighBits(LeadZ);
+  Res.One.setHighBits(LeadO);
   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
-  Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+  Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // If we're self-multiplying then bit[1] is guaranteed to be zero.
   if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
index 79df43c99744e6..37526c67f0d9e1 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -72,12 +72,7 @@ define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
 define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
 ; CHECK-LABEL: define i8 @mul_high_bits_know2(
 ; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
-; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], -2
-; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
-; CHECK-NEXT:    [[Y_NONZERO:%.*]] = or disjoint i8 [[Y]], 1
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y_NONZERO]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], -16
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 -16
 ;
   %x = or i8 %xx, -2
   %y = and i8 %yy, 4
@@ -90,11 +85,7 @@ define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
 define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
 ; CHECK-LABEL: define i8 @mul_high_bits_know3(
 ; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
-; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 124
-; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 126
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 112
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 0
 ;
   %x = or i8 %xx, -4
   %y = or i8 %yy, -2

>From 901b6e55b70fa3075eaa3e9bee1eb2f5950fe722 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 23 Oct 2024 11:31:00 +0100
Subject: [PATCH 3/6] KnownBitsTest: cover in unittests; address review

---
 llvm/unittests/Support/KnownBitsTest.cpp | 52 +++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..e374b46492622c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
   }
 }
 
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
   for (unsigned Bits : {1, 4}) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
       ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,54 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }
 
+TEST(KnownBitsTest, MulHighBits) {
+  unsigned Bits = 8;
+  SmallVector<std::pair<int, int>, 4> TestPairs = {
+      {2, 4}, {-2, -4}, {2, -4}, {-2, 4}};
+  for (auto [K1, K2] : TestPairs) {
+    KnownBits Known1(Bits), Known2(Bits);
+    if (K1 > 0) {
+      // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
+      // as we can only set leading ones in the case where LHS and RHS have
+      // different signs, when the result is known non-zero.
+      Known1.Zero |= ~(K1 | 1);
+      Known1.One |= 1;
+    } else {
+      Known1.One |= K1;
+    }
+    if (K2 > 0) {
+      // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
+      // as we can only set leading ones in the case where LHS and RHS have
+      // different signs, when the result is known non-zero.
+      Known2.Zero |= ~(K2 | 1);
+      Known2.One |= 1;
+    } else {
+      Known2.One |= K2;
+    }
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Res = N1 * N2;
+        Exact.One &= Res;
+        Exact.Zero &= ~Res;
+      });
+    });
+
+    // Check that the high bits are optimal, with the caveat that mul_ov of LHS
+    // and RHS doesn't overflow, which is the case for our TestPairs.
+    APInt Mask = APInt::getHighBitsSet(
+        Bits, (Exact.Zero | Exact.One).countLeadingOnes());
+    Exact.Zero &= Mask;
+    Exact.One &= Mask;
+    Computed.Zero &= Mask;
+    Computed.One &= Mask;
+    EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
 } // end anonymous namespace

>From 603ec715ef081fd6683a888ea940548b807ef28f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 23 Oct 2024 17:21:45 +0100
Subject: [PATCH 4/6] KnownBits: address review; more concise

---
 llvm/lib/Support/KnownBits.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index b63945f202a34d..bed1a45568c1c6 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -807,14 +807,13 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   // fit in the bitwidth (it must not overflow).
   bool HasOverflow;
   APInt Result = MaxLHS.umul_ov(MaxRHS, HasOverflow);
-  bool NegResult = LHS.isNegative() ^ RHS.isNegative();
   unsigned LeadZ = 0, LeadO = 0;
   if (!HasOverflow) {
+    if (LHS.isNegative() == RHS.isNegative())
+      LeadZ = Result.countLeadingZeros();
     // Do not set leading ones unless the result is known to be non-zero.
-    if (NegResult && LHS.isNonZero() && RHS.isNonZero())
+    else if (LHS.isNonZero() && RHS.isNonZero())
       LeadO = (-Result).countLeadingOnes();
-    else if (!NegResult)
-      LeadZ = Result.countLeadingZeros();
   }
 
   // The result of the bottom bits of an integer multiply can be

>From 04f1601eb5ce86e81865017d4b4ff74d3201b00f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 24 Oct 2024 17:19:08 +0100
Subject: [PATCH 5/6] KnownBits: generalize high-bits of mul to overflows

Make the high bits computed by KnownBits::mul optimal when the product
cannot overflow. Generalize the computation to products that may
overflow by using the min-product in addition to the max-product. In
the overflow case the result cannot be optimal in general, because that
would require examining the products between the min-product and the
max-product.
---
 llvm/lib/Support/KnownBits.cpp           |  79 ++++++++++--
 llvm/unittests/Support/KnownBitsTest.cpp | 157 ++++++++++++++++++-----
 2 files changed, 192 insertions(+), 44 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index bed1a45568c1c6..44261c9aa7b567 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,24 +796,75 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   assert((!NoUndefSelfMultiply || LHS == RHS) &&
          "Self multiplication knownbits mismatch");
 
-  // Compute the high known-0 or known-1 bits by multiplying the max of each
-  // side. Conservatively, M active bits * N active bits results in M + N bits
-  // in the result. But if we know a value is a power-of-2 for example, then
-  // this computes one more leading zero or one.
+  // Compute the high known-0 or known-1 bits by multiplying the min and max of
+  // each side.
   APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
-        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue();
+        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+        MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+        MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
 
-  // For leading zeros or ones in the result to be valid, the max product must
-  // fit in the bitwidth (it must not overflow).
+  // If MaxProduct doesn't overflow, it implies that MinProduct also won't
+  // overflow. However, if MaxProduct overflows, there is no guarantee on the
+  // MinProduct overflowing.
   bool HasOverflow;
-  APInt Result = MaxLHS.umul_ov(MaxRHS, HasOverflow);
+  APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+        MinProduct = MinLHS * MinRHS;
+
+  if (LHS.isNegative() != RHS.isNegative()) {
+    // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+    // negated to turn them into the corresponding signed-multiplication
+    // wrapped values.
+    MinProduct.negate();
+    MaxProduct.negate();
+
+    // MinProduct < MaxProduct is now MaxProduct < MinProduct.
+    std::swap(MinProduct, MaxProduct);
+  }
+
+  // Unless both MinProduct and MaxProduct are the same sign, there won't be any
+  // leading zeros or ones in the result.
   unsigned LeadZ = 0, LeadO = 0;
-  if (!HasOverflow) {
-    if (LHS.isNegative() == RHS.isNegative())
-      LeadZ = Result.countLeadingZeros();
-    // Do not set leading ones unless the result is known to be non-zero.
-    else if (LHS.isNonZero() && RHS.isNonZero())
-      LeadO = (-Result).countLeadingOnes();
+  if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+    APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+          RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+    // A product of M active bits * N active bits results in M + N bits in the
+    // result. If either of the operands is a power of two, the result has one
+    // less active bit.
+    auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+      if (A.isZero() || B.isZero())
+        return 0;
+      return A.getActiveBits() + B.getActiveBits() -
+             (A.isPowerOf2() || B.isPowerOf2());
+    };
+
+    // We want to compute the number of active bits in the difference between
+    // the non-wrapped max product and non-wrapped min product, but we want to
+    // avoid camputing the non-wrapped max/min product.
+    unsigned ActiveBitsInDiff;
+    if (MinLHS.isZero() && MinRHS.isZero())
+      ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+    else
+      ActiveBitsInDiff =
+          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+          ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+    // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+    // and B is zero.
+    auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+      return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+    };
+
+    // We uniformly handle the case where there is no max-overflow, in which
+    // case the high zeros and ones are computed optimally, and where there is,
+    // but the result shifts at most by BitWidth, in which case the high zeros
+    // and ones are not computed optimally.
+    if ((!HasOverflow || ActiveBitsInDiff <= BitWidth) &&
+        UgtCheckCorner(MaxProduct, MinProduct)) {
+      // Set the minimum leading zeros or ones from MaxProduct and MinProduct.
+      LeadZ = MaxProduct.countLeadingZeros();
+      LeadO = MinProduct.countLeadingOnes();
+    }
   }
 
   // The result of the bottom bits of an integer multiply can be
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index e374b46492622c..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -849,29 +849,83 @@ TEST(KnownBitsTest, MulLowBitsExhaustive) {
   }
 }
 
-TEST(KnownBitsTest, MulHighBits) {
-  unsigned Bits = 8;
-  SmallVector<std::pair<int, int>, 4> TestPairs = {
-      {2, 4}, {-2, -4}, {2, -4}, {-2, 4}};
-  for (auto [K1, K2] : TestPairs) {
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        bool HasOverflow = false; // Stays false if no product is visited.
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            // Assuming values are visited in ascending order, the final
+            // HasOverflow is that of the unsigned max product.
+            APInt Res = N1.umul_ov(N2, HasOverflow);
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict() && !HasOverflow) {
+          // Check that leading zeros and leading ones are optimal in the
+          // result, provided there is no overflow.
+          APInt ZerosMask =
+                    APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+                OnesMask =
+                    APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+          KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+          KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+          ExactZeros.Zero.setAllBits();
+          ExactZeros.One.setAllBits();
+          ExactOnes.Zero.setAllBits();
+          ExactOnes.One.setAllBits();
+
+          ExactZeros.Zero = Exact.Zero & ZerosMask;
+          ExactZeros.One = Exact.One & ZerosMask;
+          ComputedZeros.Zero = Computed.Zero & ZerosMask;
+          ComputedZeros.One = Computed.One & ZerosMask;
+          EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+
+          ExactOnes.Zero = Exact.Zero & OnesMask;
+          ExactOnes.One = Exact.One & OnesMask;
+          ComputedOnes.Zero = Computed.Zero & OnesMask;
+          ComputedOnes.One = Computed.One & OnesMask;
+          EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+        }
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  unsigned Bits = 4;
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+  for (auto [P1, P2] : TestPairs) {
     KnownBits Known1(Bits), Known2(Bits);
-    if (K1 > 0) {
-      // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
-      // as we can only set leading ones in the case where LHS and RHS have
-      // different signs, when the result is known non-zero.
-      Known1.Zero |= ~(K1 | 1);
-      Known1.One |= 1;
-    } else {
-      Known1.One |= K1;
+    auto [K1, U1] = P1;
+    auto [K2, U2] = P2;
+    Known1 = KnownBits::makeConstant(APInt(Bits, K1));
+    Known2 = KnownBits::makeConstant(APInt(Bits, K2));
+    if (U1 > -1) {
+      Known1.Zero.setBitVal(U1, 0);
+      Known1.One.setBitVal(U1, 0);
     }
-    if (K2 > 0) {
-      // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
-      // as we can only set leading ones in the case where LHS and RHS have
-      // different signs, when the result is known non-zero.
-      Known2.Zero |= ~(K2 | 1);
-      Known2.One |= 1;
-    } else {
-      Known2.One |= K2;
+    if (U2 > -1) {
+      Known2.Zero.setBitVal(U2, 0);
+      Known2.One.setBitVal(U2, 0);
     }
     KnownBits Computed = KnownBits::mul(Known1, Known2);
     KnownBits Exact(Bits);
@@ -886,17 +940,60 @@ TEST(KnownBitsTest, MulHighBits) {
       });
     });
 
-    // Check that the high bits are optimal, with the caveat that mul_ov of LHS
-    // and RHS doesn't overflow, which is the case for our TestPairs.
-    APInt Mask = APInt::getHighBitsSet(
-        Bits, (Exact.Zero | Exact.One).countLeadingOnes());
-    Exact.Zero &= Mask;
-    Exact.One &= Mask;
-    Computed.Zero &= Mask;
-    Computed.One &= Mask;
-    EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+    // Check that the leading zeros or ones are optimal for the given examples,
+    // which overflow. It is certainly sub-optimal on other examples.
+    APInt ZerosMask =
+              APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+          OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+    KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+    ExactZeros.Zero.setAllBits();
+    ExactZeros.One.setAllBits();
+    ExactOnes.Zero.setAllBits();
+    ExactOnes.One.setAllBits();
+
+    ExactZeros.Zero = Exact.Zero & ZerosMask;
+    ExactZeros.One = Exact.One & ZerosMask;
+    ComputedZeros.Zero = Computed.Zero & ZerosMask;
+    ComputedZeros.One = Computed.One & ZerosMask;
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    ExactOnes.Zero = Exact.Zero & OnesMask;
+    ExactOnes.One = Exact.One & OnesMask;
+    ComputedOnes.Zero = Computed.Zero & OnesMask;
+    ComputedOnes.One = Computed.One & OnesMask;
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
                             /*CheckOptimality=*/true));
   }
 }
 
+TEST(KnownBitsTest, MulStress) {
+  // Stress-test KnownBits::mul on 5- and 6-bit inputs. Only soundness is
+  // checked; high bits may be suboptimal when products overflow.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Res = N1 * N2;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict()) {
+          EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                  /*CheckOptimality=*/false));
+        }
+      });
+    });
+  }
+}
 } // end anonymous namespace

>From dd73f6c6be688cf49f7fc4551b70ef437f9150a9 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 30 Oct 2024 18:30:32 +0000
Subject: [PATCH 6/6] KnownBits: bound the min/max product difference in mul

---
 llvm/lib/Support/KnownBits.cpp | 42 +++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 44261c9aa7b567..8585c88fb57ae4 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -803,12 +803,7 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
         MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
         MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
 
-  // If MaxProduct doesn't overflow, it implies that MinProduct also won't
-  // overflow. However, if MaxProduct overflows, there is no guarantee on the
-  // MinProduct overflowing.
-  bool HasOverflow;
-  APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
-        MinProduct = MinLHS * MinRHS;
+  APInt MaxProduct = MaxLHS * MaxRHS, MinProduct = MinLHS * MinRHS;
 
   if (LHS.isNegative() != RHS.isNegative()) {
     // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
@@ -831,9 +826,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
     // A product of M active bits * N active bits results in M + N bits in the
     // result. If either of the operands is a power of two, the result has one
     // less active bit.
-    auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+    auto ProdActiveBits = [](const APInt &A, const APInt &B) {
       if (A.isZero() || B.isZero())
-        return 0;
+        return 0u;
       return A.getActiveBits() + B.getActiveBits() -
              (A.isPowerOf2() || B.isPowerOf2());
     };
@@ -841,26 +836,31 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
     // We want to compute the number of active bits in the difference between
     // the non-wrapped max product and non-wrapped min product, but we want to
     // avoid camputing the non-wrapped max/min product.
-    unsigned ActiveBitsInDiff;
-    if (MinLHS.isZero() && MinRHS.isZero())
-      ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
-    else
+    unsigned ActiveBitsInDiff = BitWidth + 1;
+    if (LHSUnknown.isZero()) {
+      ActiveBitsInDiff =
+          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown);
+    } else if (RHSUnknown.isZero()) {
       ActiveBitsInDiff =
-          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
           ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
-
-    // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
-    // and B is zero.
-    auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
-      return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
-    };
+    } else if (ProdActiveBits(MinLHS, RHSUnknown) <= BitWidth &&
+               ProdActiveBits(MinRHS, LHSUnknown) <= BitWidth &&
+               ProdActiveBits(LHSUnknown, RHSUnknown) <= BitWidth) {
+      // Slow path, which is seldom hit in practice.
+      // (MinLHS + LHSUnknown) * (MinRHS + RHSUnknown) - (MinLHS * MinRHS)
+      // = MinLHS * RHSUnknown + MinRHS * LHSUnknown + LHSUnknown * RHSUnknown.
+      APInt Res = MinLHS.umul_sat(RHSUnknown)
+                      .uadd_sat(MinRHS.umul_sat(LHSUnknown))
+                      .uadd_sat(LHSUnknown.umul_sat(RHSUnknown));
+      if (!Res.isMaxValue())
+        ActiveBitsInDiff = Res.getActiveBits();
+    }
 
     // We uniformly handle the case where there is no max-overflow, in which
     // case the high zeros and ones are computed optimally, and where there is,
     // but the result shifts at most by BitWidth, in which case the high zeros
     // and ones are not computed optimally.
-    if ((!HasOverflow || ActiveBitsInDiff <= BitWidth) &&
-        UgtCheckCorner(MaxProduct, MinProduct)) {
+    if (ActiveBitsInDiff <= BitWidth && MaxProduct.ugt(MinProduct)) {
       // Set the minimum leading zeros or ones from MaxProduct and MinProduct.
       LeadZ = MaxProduct.countLeadingZeros();
       LeadO = MinProduct.countLeadingOnes();



More information about the llvm-commits mailing list