[llvm] [SCCP] Add support for trunc nuw range. (PR #152990)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 11 03:41:09 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-function-specialization

Author: Andreas Jonson (andjo403)

<details>
<summary>Changes</summary>

Proof: https://alive2.llvm.org/ce/z/_7PVxq

Part of the changes needed to move https://github.com/llvm/llvm-project/pull/151961 to SCCP.

---
Full diff: https://github.com/llvm/llvm-project/pull/152990.diff


5 Files Affected:

- (modified) llvm/include/llvm/IR/ConstantRange.h (+3-2) 
- (modified) llvm/lib/IR/ConstantRange.cpp (+30-11) 
- (modified) llvm/lib/Transforms/Utils/SCCPSolver.cpp (+6-2) 
- (modified) llvm/test/Transforms/SCCP/conditions-ranges.ll (+41) 
- (modified) llvm/unittests/IR/ConstantRangeTest.cpp (+39) 


``````````diff
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 9a6a9db65688a..4b2fda364fdf4 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -380,8 +380,9 @@ class [[nodiscard]] ConstantRange {
   /// Return a new range in the specified integer type, which must be
   /// strictly smaller than the current type.  The returned range will
   /// correspond to the possible range of values if the source range had been
-  /// truncated to the specified type.
-  LLVM_ABI ConstantRange truncate(uint32_t BitWidth) const;
+  /// truncated to the specified type with the no-wrap flags \p NoWrapKind.
+  LLVM_ABI ConstantRange truncate(uint32_t BitWidth,
+                                  unsigned NoWrapKind = 0) const;
 
   /// Make this range have the bit width given by \p BitWidth. The
   /// value is zero extended, truncated, or left alone to make it that width.
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 2fcdbcc6a3db2..b1ea4bbd5ae6b 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -872,7 +872,8 @@ ConstantRange ConstantRange::signExtend(uint32_t DstTySize) const {
   return ConstantRange(Lower.sext(DstTySize), Upper.sext(DstTySize));
 }
 
-ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
+ConstantRange ConstantRange::truncate(uint32_t DstTySize,
+                                      unsigned NoWrapKind) const {
   assert(getBitWidth() > DstTySize && "Not a value truncation");
   if (isEmptySet())
     return getEmpty(DstTySize);
@@ -886,22 +887,36 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
   // We use the non-wrapped set code to analyze the [Lower, MaxValue) part, and
   // then we do the union with [MaxValue, Upper)
   if (isUpperWrapped()) {
-    // If Upper is greater than or equal to MaxValue(DstTy), it covers the whole
-    // truncated range.
-    if (Upper.getActiveBits() > DstTySize || Upper.countr_one() == DstTySize)
+    // If Upper is greater than MaxValue(DstTy), it covers the whole truncated
+    // range.
+    if (Upper.getActiveBits() > DstTySize)
       return getFull(DstTySize);
 
-    Union = ConstantRange(APInt::getMaxValue(DstTySize),Upper.trunc(DstTySize));
-    UpperDiv.setAllBits();
-
-    // Union covers the MaxValue case, so return if the remaining range is just
-    // MaxValue(DstTy).
-    if (LowerDiv == UpperDiv)
-      return Union;
+    // For nuw, the two parts are [0, Upper) \/ [Lower, MaxValue(DstTy) + 1).
+    if (NoWrapKind & TruncInst::NoUnsignedWrap) {
+      Union = ConstantRange(APInt::getZero(DstTySize), Upper.trunc(DstTySize));
+      UpperDiv = APInt::getOneBitSet(getBitWidth(), DstTySize);
+    } else {
+      // If Upper is equal to MaxValue(DstTy), it covers the whole truncated
+      // range.
+      if (Upper.countr_one() == DstTySize)
+        return getFull(DstTySize);
+      Union =
+          ConstantRange(APInt::getMaxValue(DstTySize), Upper.trunc(DstTySize));
+      UpperDiv.setAllBits();
+      // Union covers the MaxValue case, so return if the remaining range is
+      // just MaxValue(DstTy).
+      if (LowerDiv == UpperDiv)
+        return Union;
+    }
   }
 
   // Chop off the most significant bits that are past the destination bitwidth.
   if (LowerDiv.getActiveBits() > DstTySize) {
+    // For trunc nuw, if LowerDiv is greater than MaxValue(DstTy), no value in
+    // [LowerDiv, UpperDiv) fits in DstTy, so the result is just Union.
+    if (NoWrapKind & TruncInst::NoUnsignedWrap)
+      return Union;
     // Mask to just the signficant bits and subtract from LowerDiv/UpperDiv.
     APInt Adjust = LowerDiv & APInt::getBitsSetFrom(getBitWidth(), DstTySize);
     LowerDiv -= Adjust;
@@ -913,6 +928,10 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
     return ConstantRange(LowerDiv.trunc(DstTySize),
                          UpperDiv.trunc(DstTySize)).unionWith(Union);
 
+  if (!LowerDiv.isZero() && NoWrapKind & TruncInst::NoUnsignedWrap)
+    return ConstantRange(LowerDiv.trunc(DstTySize), APInt::getZero(DstTySize))
+        .unionWith(Union);
+
   // The truncated value wraps around. Check if we can do better than fullset.
   if (UpperDivWidth == DstTySize + 1) {
     // Clear the MSB so that UpperDiv wraps around.
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index b78c7022b9be0..a03194ded5746 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -1433,8 +1433,12 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
         OpSt.asConstantRange(I.getSrcTy(), /*UndefAllowed=*/false);
 
     Type *DestTy = I.getDestTy();
-    ConstantRange Res =
-        OpRange.castOp(I.getOpcode(), DestTy->getScalarSizeInBits());
+    ConstantRange Res = ConstantRange::getEmpty(DestTy->getScalarSizeInBits());
+    if (auto *Trunc = dyn_cast<TruncInst>(&I))
+      Res = OpRange.truncate(DestTy->getScalarSizeInBits(),
+                             Trunc->getNoWrapKind());
+    else
+      Res = OpRange.castOp(I.getOpcode(), DestTy->getScalarSizeInBits());
     mergeInValue(LV, &I, ValueLatticeElement::getRange(Res));
   } else
     markOverdefined(&I);
diff --git a/llvm/test/Transforms/SCCP/conditions-ranges.ll b/llvm/test/Transforms/SCCP/conditions-ranges.ll
index bb3764160f724..c2270287c7170 100644
--- a/llvm/test/Transforms/SCCP/conditions-ranges.ll
+++ b/llvm/test/Transforms/SCCP/conditions-ranges.ll
@@ -1436,6 +1436,47 @@ if.end:
   ret i32 0
 }
 
+define void @trunc_nuw_i1_dominating_icmp_ne_0(i8 %x) {
+; CHECK-LABEL: @trunc_nuw_i1_dominating_icmp_ne_0(
+; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne i8 [[X:%.*]], 0
+; CHECK-NEXT:    br i1 [[ICMP]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    ret void
+; CHECK:       bb2:
+; CHECK-NEXT:    ret void
+;
+  %icmp = icmp ne i8 %x, 0
+  br i1 %icmp, label %bb1, label %bb2
+bb1:
+  %c1 = trunc nuw i8 %x to i1
+  call void @use(i1 %c1)
+  ret void
+bb2:
+  ret void
+}
+
+define void @neg_trunc_i1_dominating_icmp_ne_0(i8 %x) {
+; CHECK-LABEL: @neg_trunc_i1_dominating_icmp_ne_0(
+; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne i8 [[X:%.*]], 0
+; CHECK-NEXT:    br i1 [[ICMP]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[C1:%.*]] = trunc i8 [[X]] to i1
+; CHECK-NEXT:    call void @use(i1 [[C1]])
+; CHECK-NEXT:    ret void
+; CHECK:       bb2:
+; CHECK-NEXT:    ret void
+;
+  %icmp = icmp ne i8 %x, 0
+  br i1 %icmp, label %bb1, label %bb2
+bb1:
+  %c1 = trunc i8 %x to i1
+  call void @use(i1 %c1)
+  ret void
+bb2:
+  ret void
+}
+
 define i1 @ptr_icmp_data_layout() {
 ; CHECK-LABEL: @ptr_icmp_data_layout(
 ; CHECK-NEXT:    ret i1 false
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index bcb5d498c8cb9..6bb6903bae703 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -451,6 +451,45 @@ TEST_F(ConstantRangeTest, Trunc) {
   EXPECT_EQ(SevenOne.truncate(2), ConstantRange(APInt(2, 3), APInt(2, 1)));
 }
 
+TEST_F(ConstantRangeTest, TruncNuw) {
+  auto Range = [](unsigned NumBits, unsigned Lower, unsigned Upper) {
+    return ConstantRange(APInt(NumBits, Lower), APInt(NumBits, Upper));
+  };
+  // trunc([0, 4), 3->2) = full
+  EXPECT_TRUE(
+      Range(3, 0, 4).truncate(2, TruncInst::NoUnsignedWrap).isFullSet());
+  // trunc([0, 3), 3->2) = [0, 3)
+  EXPECT_EQ(Range(3, 0, 3).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 0, 3));
+  // trunc([1, 3), 3->2) = [1, 3)
+  EXPECT_EQ(Range(3, 1, 3).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 1, 3));
+  // trunc([1, 5), 3->2) = [1, 0)
+  EXPECT_EQ(Range(3, 1, 5).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 1, 0));
+  // trunc([4, 7), 3->2) = empty
+  EXPECT_TRUE(
+      Range(3, 4, 7).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet());
+  // trunc([4, 0), 3->2) = empty
+  EXPECT_TRUE(
+      Range(3, 4, 0).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet());
+  // trunc([4, 1), 3->2) = [0, 1)
+  EXPECT_EQ(Range(3, 4, 1).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 0, 1));
+  // trunc([3, 1), 3->2) = [3, 1)
+  EXPECT_EQ(Range(3, 3, 1).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 3, 1));
+  // trunc([3, 0), 3->2) = [3, 0)
+  EXPECT_EQ(Range(3, 3, 0).truncate(2, TruncInst::NoUnsignedWrap),
+            Range(2, 3, 0));
+  // trunc([1, 0), 2->1) = [1, 0)
+  EXPECT_EQ(Range(2, 1, 0).truncate(1, TruncInst::NoUnsignedWrap),
+            Range(1, 1, 0));
+  // trunc([2, 1), 2->1) = [0, 1)
+  EXPECT_EQ(Range(2, 2, 1).truncate(1, TruncInst::NoUnsignedWrap),
+            Range(1, 0, 1));
+}
+
 TEST_F(ConstantRangeTest, ZExt) {
   ConstantRange ZFull = Full.zeroExtend(20);
   ConstantRange ZEmpty = Empty.zeroExtend(20);

``````````

</details>


https://github.com/llvm/llvm-project/pull/152990


More information about the llvm-commits mailing list