[llvm] [ConstantRange] Handle `Intrinsic::cttz` and `Intrinsic::ctpop` (PR #67917)

via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 1 23:28:35 PDT 2023


================
@@ -1735,6 +1746,122 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
   return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
                      APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
 }
+/// Returns a range covering cttz(X) for every X in the non-wrapped interval
+/// [Lower, Upper). Requires Lower <= Upper (unsigned comparison).
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  const unsigned Width = Lower.getBitWidth();
+
+  // Empty interval: there is nothing to count.
+  if (Upper == Lower)
+    return ConstantRange::getEmpty(Width);
+
+  // Exactly one value: the result is known precisely.
+  if (Upper == Lower + 1)
+    return ConstantRange(APInt(Width, Lower.countr_zero()));
+
+  // The interval contains zero, whose cttz is Width, so [0, Width] is used.
+  if (Lower.isZero())
+    return ConstantRange(APInt::getZero(Width), APInt(Width, Width + 1));
+
+  // At least two consecutive values remain, so one of them is odd and the
+  // minimum is 0. For the maximum, split each value into the longest common
+  // prefix (LCP) of Lower and Upper - 1 plus a suffix of SuffixLen bits:
+  //  - {LCP, 1, 0...0} lies in the interval and has SuffixLen - 1 trailing
+  //    zeros;
+  //  - only Lower itself can have more, namely when it is {LCP, 0...0}.
+  const unsigned PrefixLen = (Lower ^ (Upper - 1)).countl_zero();
+  const unsigned SuffixLen = Width - PrefixLen;
+  const unsigned FromLower = Lower.countr_zero() + 1;
+  const unsigned ExclusiveMax = SuffixLen > FromLower ? SuffixLen : FromLower;
+  return ConstantRange(APInt::getZero(Width), APInt(Width, ExclusiveMax));
+}
+
+// Computes a range covering cttz(X) for every X in this range. When
+// ZeroIsPoison is set, cttz(0) is poison, so zero is dropped from the input
+// before the result is computed.
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+
+  // Result range for the high part [Lower, UMax] of an upper-wrapped range,
+  // where UMax = 2^BitWidth - 1. With N = UMax - Lower + 1 values there, the
+  // largest multiple of 2^K that is <= UMax is 2^BitWidth - 2^K. That value
+  // lies in the high part iff 2^K <= N, so the maximum is logBase2(N). UMax
+  // itself is odd, so the minimum is 0.
+  auto CttzOfHighPart = [&]() {
+    APInt NumHighValues = getUnsignedMax() - getLower() + 1;
+    return ConstantRange(Zero, APInt(BitWidth, NumHighValues.logBase2() + 1));
+  };
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // cttz(0) is poison here, so zero is removed from the input. A range that
+    // contains zero has one of three shapes:
+    // 1) Lower == 0: [0, Upper), e.g. [0, 1), [0, 2).
+    // 2) Upper == 1: upper-wrapped, with zero as its last element, e.g.
+    //    [3, 1) == [3, UMax] u {0}.
+    // 3) Otherwise, zero lies strictly inside an upper-wrapped range, e.g.
+    //    [3, 2) == [3, UMax] u [0, 1].
+    if (getLower().isZero()) {
+      // [0, 1) contains only the poison input, so no result is possible.
+      if (getUpper().isOne())
+        return getEmpty();
+      return getUnsignedCountTrailingZerosRange(getLower() + 1, getUpper());
+    }
+    // Case 2: only the high part [Lower, UMax] remains.
+    if (getUpper().isOne())
+      return CttzOfHighPart();
+    // Case 3: [Lower, UMax] u [1, Upper).
+    return CttzOfHighPart().unionWith(
+        getUnsignedCountTrailingZerosRange(APInt(BitWidth, 1), getUpper()));
+  }
+
+  // Every input is possible, including zero (cttz(0) == BitWidth).
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  if (!isUpperWrapped())
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+  // Upper-wrapped: [Lower, UMax] u [0, Upper).
+  return CttzOfHighPart().unionWith(
+      getUnsignedCountTrailingZerosRange(Zero, getUpper()));
+}
+
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+                                              const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.popcount()));
+
+  APInt Max = Upper - 1;
+  // Calculate longest common prefix.
+  unsigned LCPLength = (Lower ^ Max).countl_zero();
+  unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount();
+  // If Lower is {LCP, 000...}, the minimum is the popcount of LCP.
+  // Otherwise, the minimum is the popcount of LCP + 1.
+  unsigned MinBits =
+      LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0);
+  // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth -
+  // length of LCP).
+  // Otherwise, the minimum is the popcount of LCP + (BitWidth -
+  // length of LCP - 1).
+  unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) +
+                     (Max.countr_one() >= BitWidth - LCPLength ? 1 : 0);
----------------
goldsteinn wrote:

This is fairly imprecise.
I'm not sure whether there is a clean bitwise trick, but I think a precise log2(bitwidth) approach would probably be better.

i.e.:
```
APInt Max = APInt::getZero(BitWidth);
APInt Min = APInt::getZero(BitWidth);
// Max: greedily set low bits while the value stays below Upper.
for(unsigned I = 0; I < BitWidth; ++I) {
   APInt NextMax = Max;
   NextMax.setBit(I);
   if(NextMax.uge(Upper)) break;
   Max = NextMax;
}
MaxBits = Max.popcount();
// Min: greedily set high bits (staying below Upper) until the value reaches Lower.
for(unsigned I = BitWidth; I != 0; --I) {
   if(Min.uge(Lower)) break;
   APInt NextMin = Min;
   NextMin.setBit(I - 1);
   if(NextMin.uge(Upper)) continue;
   Min = NextMin;
}
```

Obviously this is fairly optimizable, but I don't think it will be a serious bottleneck, and I think it's worth getting the precision.

https://github.com/llvm/llvm-project/pull/67917


More information about the llvm-commits mailing list