[llvm] [ConstantRange] Handle `Intrinsic::cttz` and `Intrinsic::ctpop` (PR #67917)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 1 07:19:54 PDT 2023
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/67917
This patch adds support for cttz and ctpop intrinsics in ConstantRange. It calculates the range in O(1) with the LCP-based method.
Migrated from https://reviews.llvm.org/D153505.
>From b5d134c88a04c524b1d9120a1c1a5dae3722904c Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 1 Oct 2023 22:17:35 +0800
Subject: [PATCH] [ConstantRange] Handle `Intrinsic::cttz` and
`Intrinsic::ctpop`
---
llvm/include/llvm/IR/ConstantRange.h | 7 +
llvm/lib/IR/ConstantRange.cpp | 127 ++++++++++++++++++
.../CorrelatedValuePropagation/range.ll | 54 ++++++++
llvm/unittests/IR/ConstantRangeTest.cpp | 20 +++
4 files changed, 208 insertions(+)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index ca36732e4e2e8c2..e718e6e7e3403de 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,13 @@ class [[nodiscard]] ConstantRange {
/// ignoring a possible zero value contained in the input range.
ConstantRange ctlz(bool ZeroIsPoison = false) const;
+ /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+ /// ignoring a possible zero value contained in the input range.
+ ConstantRange cttz(bool ZeroIsPoison = false) const;
+
+ /// Calculate ctpop range.
+ ConstantRange ctpop() const;
+
/// Represents whether an operation on the given constant range is known to
/// always or never overflow.
enum class OverflowResult {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 3d71b20f7e853e0..f34a2749543c321 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -949,6 +949,8 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
case Intrinsic::smax:
case Intrinsic::abs:
case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::ctpop:
return true;
default:
return false;
@@ -986,6 +988,15 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
}
+ case Intrinsic::cttz: {
+ const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+ assert(ZeroIsPoison && "Must be known (immarg)");
+ assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+ return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+ }
+ case Intrinsic::ctpop: {
+ return Ops[0].ctpop();
+ }
default:
assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
llvm_unreachable("Unsupported intrinsic");
@@ -1735,6 +1746,122 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
}
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+ const APInt &Upper) {
+ assert(Lower.ule(Upper));
+ unsigned BitWidth = Lower.getBitWidth();
+ if (Lower == Upper)
+ return ConstantRange::getEmpty(BitWidth);
+ if (Lower + 1 == Upper)
+ return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
+ if (Lower.isZero())
+ return ConstantRange(APInt::getZero(BitWidth),
+ APInt(BitWidth, BitWidth + 1));
+
+ // Calculate longest common prefix.
+ unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero();
+ // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero().
+ // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}).
+ return ConstantRange(
+ APInt::getZero(BitWidth),
+ APInt(BitWidth, std::max(BitWidth - LCPLength, Lower.countr_zero() + 1)));
+}
+
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+ if (isEmptySet())
+ return getEmpty();
+
+ APInt Zero = APInt::getZero(getBitWidth());
+
+ if (ZeroIsPoison && contains(Zero)) {
+ // ZeroIsPoison is set, and zero is contained. We discern three cases, in
+ // which a zero can appear:
+ // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
+ // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc.
+ // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc.
+
+ if (getLower().isZero()) {
+ if ((getUpper() - 1).isZero()) {
+ // We have in input interval of kind [0, 1). In this case we cannot
+ // really help but return empty-set.
+ return getEmpty();
+ }
+
+ // Compute the resulting range by excluding zero from Lower.
+ return getUnsignedCountTrailingZerosRange(getLower() + 1, getUpper());
+ } else if ((getUpper() - 1).isZero()) {
+ // Compute the resulting range by excluding zero from Upper.
+ return ConstantRange(
+ Zero, APInt(getBitWidth(),
+ (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+ } else {
+ ConstantRange CR1(
+ Zero, APInt(getBitWidth(),
+ (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+ ConstantRange CR2 = getUnsignedCountTrailingZerosRange(
+ APInt(getBitWidth(), 1), getUpper());
+ return CR1.unionWith(CR2);
+ }
+ }
+
+ if (isFullSet()) {
+ return getNonEmpty(Zero, APInt(getBitWidth(), getBitWidth() + 1));
+ }
+ if (!isUpperWrapped()) {
+ return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+ }
+ ConstantRange CR1(
+ Zero,
+ APInt(getBitWidth(), (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+ ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, getUpper());
+ return CR1.unionWith(CR2);
+}
+
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+ const APInt &Upper) {
+ assert(Lower.ule(Upper));
+ unsigned BitWidth = Lower.getBitWidth();
+ if (Lower == Upper)
+ return ConstantRange::getEmpty(BitWidth);
+ if (Lower + 1 == Upper)
+ return ConstantRange(APInt(BitWidth, Lower.popcount()));
+
+ APInt Max = Upper - 1;
+ // Calculate longest common prefix.
+ unsigned LCPLength = (Lower ^ Max).countl_zero();
+ unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount();
+ // If Lower is {LCP, 000...}, the minimum is the popcount of LCP.
+ // Otherwise, the minimum is the popcount of LCP + 1.
+ unsigned MinBits =
+ LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0);
+ // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth -
+ // length of LCP).
+ // Otherwise, the minimum is the popcount of LCP + (BitWidth -
+ // length of LCP - 1).
+ unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) +
+ (Max.countr_one() >= BitWidth - LCPLength ? 1 : 0);
+ return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits));
+}
+
+ConstantRange ConstantRange::ctpop() const {
+ if (isEmptySet())
+ return getEmpty();
+
+ unsigned BitWidth = getBitWidth();
+ APInt Zero = APInt::getZero(BitWidth);
+ if (isFullSet()) {
+ return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+ }
+ if (!isUpperWrapped()) {
+ return getUnsignedPopCountRange(getLower(), getUpper());
+ }
+ ConstantRange CR1 = ConstantRange(
+ APInt(BitWidth,
+ BitWidth - (getUnsignedMax() - getLower() + 1).logBase2()),
+ APInt(BitWidth, BitWidth + 1)); // [lower, intmax]
+ ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper()); // [0, upper)
+ return CR1.unionWith(CR2);
+}
ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
const ConstantRange &Other) const {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
index 7e89f864c8110ee..182a0bbef255de8 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1010,6 +1010,60 @@ else:
ret i1 %res2
}
+define i1 @cttz_fold(i16 %x) {
+; CHECK-LABEL: @cttz_fold(
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK: if:
+; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT: ret i1 false
+; CHECK: else:
+; CHECK-NEXT: [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT: [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8
+; CHECK-NEXT: ret i1 [[RES2]]
+;
+ %cmp = icmp ult i16 %x, 256
+ br i1 %cmp, label %if, label %else
+
+if:
+ %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+ %res = icmp uge i16 %cttz, 8
+ ret i1 %res
+
+else:
+ %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+ %res2 = icmp ult i16 %cttz2, 8
+ ret i1 %res2
+}
+
+define i1 @ctpop_fold(i16 %x) {
+; CHECK-LABEL: @ctpop_fold(
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK: if:
+; CHECK-NEXT: [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT: ret i1 true
+; CHECK: else:
+; CHECK-NEXT: [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT: [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8
+; CHECK-NEXT: ret i1 [[RES2]]
+;
+ %cmp = icmp ult i16 %x, 256
+ br i1 %cmp, label %if, label %else
+
+if:
+ %ctpop = call i16 @llvm.ctpop.i16(i16 %x)
+ %res = icmp ule i16 %ctpop, 8
+ ret i1 %res
+
+else:
+ %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x)
+ %res2 = icmp ugt i16 %ctpop2, 8
+ ret i1 %res2
+}
+
declare i16 @llvm.ctlz.i16(i16, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i16 @llvm.ctpop.i16(i16)
declare i16 @llvm.abs.i16(i16, i1)
declare void @llvm.assume(i1)
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1cb358a26062ca5..e505af5d3275ef2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2438,6 +2438,26 @@ TEST_F(ConstantRangeTest, Ctlz) {
});
}
+TEST_F(ConstantRangeTest, Cttz) {
+ TestUnaryOpExhaustive(
+ [](const ConstantRange &CR) { return CR.cttz(); },
+ [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); });
+
+ TestUnaryOpExhaustive(
+ [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); },
+ [](const APInt &N) -> std::optional<APInt> {
+ if (N.isZero())
+ return std::nullopt;
+ return APInt(N.getBitWidth(), N.countr_zero());
+ });
+}
+
+TEST_F(ConstantRangeTest, Ctpop) {
+ TestUnaryOpExhaustive(
+ [](const ConstantRange &CR) { return CR.ctpop(); },
+ [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); });
+}
+
TEST_F(ConstantRangeTest, castOps) {
ConstantRange A(APInt(16, 66), APInt(16, 128));
ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8);
More information about the llvm-commits
mailing list