[llvm] [ConstantRange] Handle `Intrinsic::cttz` (PR #67917)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 6 00:33:24 PST 2023


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/67917

>From ddceeffc34acbeb5d37559f50f7570ea8e45745b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 4 Nov 2023 22:16:38 +0800
Subject: [PATCH 1/2] [ConstantRange] Handle `Intrinsic::cttz`

---
 llvm/include/llvm/IR/ConstantRange.h          |  4 +
 llvm/lib/IR/ConstantRange.cpp                 | 76 +++++++++++++++++++
 .../CorrelatedValuePropagation/range.ll       |  3 +-
 llvm/unittests/IR/ConstantRangeTest.cpp       | 14 ++++
 4 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 017f1f36d8a663e..e718e6e7e3403de 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,10 @@ class [[nodiscard]] ConstantRange {
   /// ignoring a possible zero value contained in the input range.
   ConstantRange ctlz(bool ZeroIsPoison = false) const;
 
+  /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+  /// ignoring a possible zero value contained in the input range.
+  ConstantRange cttz(bool ZeroIsPoison = false) const;
+
   /// Calculate ctpop range.
   ConstantRange ctpop() const;
 
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index af93b69a11c4dab..bce91508c00f219 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -949,6 +949,7 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
   case Intrinsic::smax:
   case Intrinsic::abs:
   case Intrinsic::ctlz:
+  case Intrinsic::cttz:
   case Intrinsic::ctpop:
     return true;
   default:
@@ -987,6 +988,12 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
     return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
   }
+  case Intrinsic::cttz: {
+    const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+    assert(ZeroIsPoison && "Must be known (immarg)");
+    assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+    return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+  }
   case Intrinsic::ctpop:
     return Ops[0].ctpop();
   default:
@@ -1739,6 +1746,75 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
                      APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
 }
 
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(!ConstantRange(Lower, Upper).isWrappedSet() &&
+         "Unexpected wrapped set.");
+  assert(Lower != Upper && "Unexpected empty set.");
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
+  if (Lower.isZero())
+    return ConstantRange(APInt::getZero(BitWidth),
+                         APInt(BitWidth, BitWidth + 1));
+
+  // Bit length of the prefix shared by Lower and Upper - 1 (the max element).
+  unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero();
+  // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero().
+  // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}).
+  return ConstantRange(
+      APInt::getZero(BitWidth),
+      APInt(BitWidth,
+            std::max(BitWidth - LCPLength - 1, Lower.countr_zero()) + 1));
+}
+
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  APInt Zero = APInt::getZero(getBitWidth());
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // ZeroIsPoison is set, and zero is contained. We discern three cases, in
+    // which a zero can appear:
+    // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
+    // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc.
+    // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc.
+
+    if (getLower().isZero()) {
+      if (getUpper() == 1) {
+        // We have in input interval of kind [0, 1). In this case we cannot
+        // really help but return empty-set.
+        return getEmpty();
+      }
+
+      // Compute the resulting range by excluding zero from Lower.
+      return getUnsignedCountTrailingZerosRange(APInt(getBitWidth(), 1),
+                                                getUpper());
+    } else if (getUpper() == 1) {
+      // Compute the resulting range by excluding zero from Upper.
+      return getUnsignedCountTrailingZerosRange(getLower(), Zero);
+    } else {
+      ConstantRange CR1 = getUnsignedCountTrailingZerosRange(getLower(), Zero);
+      ConstantRange CR2 = getUnsignedCountTrailingZerosRange(
+          APInt(getBitWidth(), 1), getUpper());
+      return CR1.unionWith(CR2);
+    }
+  }
+
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(getBitWidth(), getBitWidth() + 1));
+  if (!isWrappedSet())
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+  // The range is wrapped. We decompose it into two ranges, [0, Upper) and
+  // [Lower, 0).
+  // Handle [Lower, 0)
+  ConstantRange CR1 = getUnsignedCountTrailingZerosRange(getLower(), Zero);
+  // Handle [0, Upper)
+  ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, getUpper());
+  return CR1.unionWith(CR2);
+}
+
 static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
                                               const APInt &Upper) {
   assert(!ConstantRange(Lower, Upper).isWrappedSet() &&
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
index cf85753e59c8083..17d979dcff23d07 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1016,8 +1016,7 @@ define i1 @cttz_fold(i16 %x) {
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
-; CHECK-NEXT:    [[RES:%.*]] = icmp uge i16 [[CTTZ]], 8
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ; CHECK:       else:
 ; CHECK-NEXT:    ret i1 false
 ;
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 12facfb22fb3c73..e505af5d3275ef2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2438,6 +2438,20 @@ TEST_F(ConstantRangeTest, Ctlz) {
       });
 }
 
+TEST_F(ConstantRangeTest, Cttz) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); });
+
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); },
+      [](const APInt &N) -> std::optional<APInt> {
+        if (N.isZero())
+          return std::nullopt;
+        return APInt(N.getBitWidth(), N.countr_zero());
+      });
+}
+
 TEST_F(ConstantRangeTest, Ctpop) {
   TestUnaryOpExhaustive(
       [](const ConstantRange &CR) { return CR.ctpop(); },

>From 9d16764173ace656dd136c822f376598641aca82 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 6 Nov 2023 00:51:16 +0800
Subject: [PATCH 2/2] fixup! [ConstantRange] Handle `Intrinsic::cttz`

---
 llvm/lib/IR/ConstantRange.cpp | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index bce91508c00f219..cbb64b299e648e4 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1772,8 +1772,8 @@ ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
   if (isEmptySet())
     return getEmpty();
 
-  APInt Zero = APInt::getZero(getBitWidth());
-
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
   if (ZeroIsPoison && contains(Zero)) {
     // ZeroIsPoison is set, and zero is contained. We discern three cases, in
     // which a zero can appear:
@@ -1781,37 +1781,36 @@ ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
     // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc.
     // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc.
 
-    if (getLower().isZero()) {
-      if (getUpper() == 1) {
+    if (Lower.isZero()) {
+      if (Upper == 1) {
         // We have in input interval of kind [0, 1). In this case we cannot
         // really help but return empty-set.
         return getEmpty();
       }
 
       // Compute the resulting range by excluding zero from Lower.
-      return getUnsignedCountTrailingZerosRange(APInt(getBitWidth(), 1),
-                                                getUpper());
-    } else if (getUpper() == 1) {
+      return getUnsignedCountTrailingZerosRange(APInt(BitWidth, 1), Upper);
+    } else if (Upper == 1) {
       // Compute the resulting range by excluding zero from Upper.
-      return getUnsignedCountTrailingZerosRange(getLower(), Zero);
+      return getUnsignedCountTrailingZerosRange(Lower, Zero);
     } else {
-      ConstantRange CR1 = getUnsignedCountTrailingZerosRange(getLower(), Zero);
-      ConstantRange CR2 = getUnsignedCountTrailingZerosRange(
-          APInt(getBitWidth(), 1), getUpper());
+      ConstantRange CR1 = getUnsignedCountTrailingZerosRange(Lower, Zero);
+      ConstantRange CR2 =
+          getUnsignedCountTrailingZerosRange(APInt(BitWidth, 1), Upper);
       return CR1.unionWith(CR2);
     }
   }
 
   if (isFullSet())
-    return getNonEmpty(Zero, APInt(getBitWidth(), getBitWidth() + 1));
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
   if (!isWrappedSet())
-    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+    return getUnsignedCountTrailingZerosRange(Lower, Upper);
   // The range is wrapped. We decompose it into two ranges, [0, Upper) and
   // [Lower, 0).
   // Handle [Lower, 0)
-  ConstantRange CR1 = getUnsignedCountTrailingZerosRange(getLower(), Zero);
+  ConstantRange CR1 = getUnsignedCountTrailingZerosRange(Lower, Zero);
   // Handle [0, Upper)
-  ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, getUpper());
+  ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, Upper);
   return CR1.unionWith(CR2);
 }
 



More information about the llvm-commits mailing list