[llvm] Add ConstantRangeList::subtract(ConstantRange) (PR #97093)

Haopeng Liu via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 12:58:49 PDT 2024


https://github.com/haopliu updated https://github.com/llvm/llvm-project/pull/97093

>From 26f51043caf2e3d7cbcf51df3eacea93d5fe15da Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Fri, 28 Jun 2024 18:15:12 +0000
Subject: [PATCH 1/2] Add ConstantRangeList::subtract(ConstantRange)

---
 llvm/include/llvm/IR/ConstantRangeList.h    |  6 +++
 llvm/lib/IR/ConstantRangeList.cpp           | 59 +++++++++++++++++++++
 llvm/unittests/IR/ConstantRangeListTest.cpp | 44 +++++++++++++++
 3 files changed, 109 insertions(+)

diff --git a/llvm/include/llvm/IR/ConstantRangeList.h b/llvm/include/llvm/IR/ConstantRangeList.h
index 46edaff19e73f..9aae52dac130e 100644
--- a/llvm/include/llvm/IR/ConstantRangeList.h
+++ b/llvm/include/llvm/IR/ConstantRangeList.h
@@ -72,6 +72,12 @@ class [[nodiscard]] ConstantRangeList {
                          APInt(64, Upper, /*isSigned=*/true)));
   }
 
+  void subtract(const ConstantRange &SubRange);
+  void subtract(int64_t Lower, int64_t Upper) {
+    subtract(ConstantRange(APInt(64, Lower, /*isSigned=*/true),
+                           APInt(64, Upper, /*isSigned=*/true)));
+  }
+
   /// Return the range list that results from the union of this
   /// ConstantRangeList with another ConstantRangeList, "CRL".
   ConstantRangeList unionWith(const ConstantRangeList &CRL) const;
diff --git a/llvm/lib/IR/ConstantRangeList.cpp b/llvm/lib/IR/ConstantRangeList.cpp
index 0373524a09f10..2db5de86b3c76 100644
--- a/llvm/lib/IR/ConstantRangeList.cpp
+++ b/llvm/lib/IR/ConstantRangeList.cpp
@@ -81,6 +81,65 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) {
   }
 }
 
+void ConstantRangeList::subtract(const ConstantRange &SubRange) {
+  if (SubRange.isEmptySet())
+    return;
+  assert(!SubRange.isFullSet() && "Do not support full set");
+  assert(SubRange.getLower().slt(SubRange.getUpper()));
+  assert(getBitWidth() == SubRange.getBitWidth());
+  // Handle common cases.
+  if (empty() || Ranges.back().getUpper().sle(SubRange.getLower())) {
+    return;
+  }
+  if (SubRange.getUpper().sle(Ranges.front().getLower())) {
+    return;
+  }
+
+  SmallVector<ConstantRange, 2> Result;
+  auto AppendRange = [&Result](APInt Start, APInt End) {
+    if (Start.slt(End))
+      Result.push_back(ConstantRange(Start, End));
+  };
+  for (auto &Range : Ranges) {
+    if (SubRange.getUpper().sle(Range.getLower()) ||
+        Range.getUpper().sle(SubRange.getLower())) {
+      // "Range" and "SubRange" do not overlap.
+      //       L---U        : Range
+      // L---U              : SubRange (Case1)
+      //             L---U  : SubRange (Case2)
+      Result.push_back(Range);
+    } else if (Range.getLower().sle(SubRange.getLower()) &&
+               SubRange.getUpper().sle(Range.getUpper())) {
+      // "Range" contains "SubRange".
+      //       L---U        : Range
+      //        L-U         : SubRange
+      // Note that ConstantRange::contains(ConstantRange) checks unsigned,
+      // but we need signed checking here.
+      AppendRange(Range.getLower(), SubRange.getLower());
+      AppendRange(SubRange.getUpper(), Range.getUpper());
+    } else if (SubRange.getLower().sle(Range.getLower()) &&
+               Range.getUpper().sle(SubRange.getUpper())) {
+      // "SubRange" contains "Range".
+      //        L-U        : Range
+      //       L---U       : SubRange
+      continue;
+    } else if (Range.getLower().sge(SubRange.getLower()) &&
+               Range.getLower().sle(SubRange.getUpper())) {
+      // "Range" and "SubRange" overlap at the left.
+      //       L---U        : Range
+      //     L---U          : SubRange
+      AppendRange(SubRange.getUpper(), Range.getUpper());
+    } else {
+      // "Range" and "SubRange" overlap at the right.
+      //       L---U        : Range
+      //         L---U      : SubRange
+      AppendRange(Range.getLower(), SubRange.getLower());
+    }
+  }
+
+  Ranges.assign(Result.begin(), Result.end());
+}
+
 ConstantRangeList
 ConstantRangeList::unionWith(const ConstantRangeList &CRL) const {
   assert(getBitWidth() == CRL.getBitWidth() &&
diff --git a/llvm/unittests/IR/ConstantRangeListTest.cpp b/llvm/unittests/IR/ConstantRangeListTest.cpp
index b679dd3a33d5d..da3cb330871b3 100644
--- a/llvm/unittests/IR/ConstantRangeListTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeListTest.cpp
@@ -101,6 +101,50 @@ ConstantRangeList GetCRL(ArrayRef<std::pair<APInt, APInt>> Pairs) {
   return ConstantRangeList(Ranges);
 }
 
+TEST_F(ConstantRangeListTest, Subtract) {
+  APInt AP0 = APInt(64, 0, /*isSigned=*/true);
+  APInt AP2 = APInt(64, 2, /*isSigned=*/true);
+  APInt AP3 = APInt(64, 3, /*isSigned=*/true);
+  APInt AP4 = APInt(64, 4, /*isSigned=*/true);
+  APInt AP8 = APInt(64, 8, /*isSigned=*/true);
+  APInt AP10 = APInt(64, 10, /*isSigned=*/true);
+  APInt AP11 = APInt(64, 11, /*isSigned=*/true);
+  APInt AP12 = APInt(64, 12, /*isSigned=*/true);
+  ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});
+
+  // Execute ConstantRangeList::subtract(ConstantRange) and check the result
+  // is expected. Pass "CRL" by value so that subtract() does not affect the
+  // argument in caller.
+  auto SubtractAndCheck = [](ConstantRangeList CRL,
+                             const std::pair<int64_t, int64_t> &Range,
+                             const ConstantRangeList &ExpectedCRL) {
+    CRL.subtract(Range.first, Range.second);
+    EXPECT_EQ(CRL, ExpectedCRL);
+  };
+
+  // No overlap
+  SubtractAndCheck(CRL, {-4, 0}, CRL);
+  SubtractAndCheck(CRL, {4, 8}, CRL);
+  SubtractAndCheck(CRL, {12, 16}, CRL);
+
+  // Overlap (left or right)
+  SubtractAndCheck(CRL, {-4, 2}, GetCRL({{AP2, AP4}, {AP8, AP12}}));
+  SubtractAndCheck(CRL, {-4, 4}, GetCRL({{AP8, AP12}}));
+  SubtractAndCheck(CRL, {-4, 8}, GetCRL({{AP8, AP12}}));
+  SubtractAndCheck(CRL, {10, 16}, GetCRL({{AP0, AP4}, {AP8, AP10}}));
+  SubtractAndCheck(CRL, {8, 16}, GetCRL({{AP0, AP4}}));
+  SubtractAndCheck(CRL, {6, 16}, GetCRL({{AP0, AP4}}));
+
+  // Subset
+  SubtractAndCheck(CRL, {2, 3}, GetCRL({{AP0, AP2}, {AP3, AP4}, {AP8, AP12}}));
+  SubtractAndCheck(CRL, {10, 11},
+                   GetCRL({{AP0, AP4}, {AP8, AP10}, {AP11, AP12}}));
+
+  // Superset
+  SubtractAndCheck(CRL, {0, 12}, GetCRL({}));
+  SubtractAndCheck(CRL, {-4, 16}, GetCRL({}));
+}
+
 TEST_F(ConstantRangeListTest, Union) {
   APInt APN4 = APInt(64, -4, /*isSigned=*/true);
   APInt APN2 = APInt(64, -2, /*isSigned=*/true);

>From 386d94b499e435e28abd471e754db452671ceb67 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Fri, 28 Jun 2024 19:58:37 +0000
Subject: [PATCH 2/2] Drop int64_t subtract overload, rename helper, and add boundary tests

---
 llvm/include/llvm/IR/ConstantRangeList.h    |  4 ----
 llvm/lib/IR/ConstantRangeList.cpp           | 20 +++++++++-----------
 llvm/unittests/IR/ConstantRangeListTest.cpp | 12 ++++++++++--
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/IR/ConstantRangeList.h b/llvm/include/llvm/IR/ConstantRangeList.h
index 9aae52dac130e..44d1daebe49e4 100644
--- a/llvm/include/llvm/IR/ConstantRangeList.h
+++ b/llvm/include/llvm/IR/ConstantRangeList.h
@@ -73,10 +73,6 @@ class [[nodiscard]] ConstantRangeList {
   }
 
   void subtract(const ConstantRange &SubRange);
-  void subtract(int64_t Lower, int64_t Upper) {
-    subtract(ConstantRange(APInt(64, Lower, /*isSigned=*/true),
-                           APInt(64, Upper, /*isSigned=*/true)));
-  }
 
   /// Return the range list that results from the union of this
   /// ConstantRangeList with another ConstantRangeList, "CRL".
diff --git a/llvm/lib/IR/ConstantRangeList.cpp b/llvm/lib/IR/ConstantRangeList.cpp
index 2db5de86b3c76..4ca4fe4f4c06d 100644
--- a/llvm/lib/IR/ConstantRangeList.cpp
+++ b/llvm/lib/IR/ConstantRangeList.cpp
@@ -82,21 +82,19 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) {
 }
 
 void ConstantRangeList::subtract(const ConstantRange &SubRange) {
-  if (SubRange.isEmptySet())
+  if (SubRange.isEmptySet() || empty())
     return;
   assert(!SubRange.isFullSet() && "Do not support full set");
   assert(SubRange.getLower().slt(SubRange.getUpper()));
   assert(getBitWidth() == SubRange.getBitWidth());
   // Handle common cases.
-  if (empty() || Ranges.back().getUpper().sle(SubRange.getLower())) {
+  if (Ranges.back().getUpper().sle(SubRange.getLower()))
     return;
-  }
-  if (SubRange.getUpper().sle(Ranges.front().getLower())) {
+  if (SubRange.getUpper().sle(Ranges.front().getLower()))
     return;
-  }
 
   SmallVector<ConstantRange, 2> Result;
-  auto AppendRange = [&Result](APInt Start, APInt End) {
+  auto AppendRangeIfNonEmpty = [&Result](APInt Start, APInt End) {
     if (Start.slt(End))
       Result.push_back(ConstantRange(Start, End));
   };
@@ -115,8 +113,8 @@ void ConstantRangeList::subtract(const ConstantRange &SubRange) {
       //        L-U         : SubRange
       // Note that ConstantRange::contains(ConstantRange) checks unsigned,
       // but we need signed checking here.
-      AppendRange(Range.getLower(), SubRange.getLower());
-      AppendRange(SubRange.getUpper(), Range.getUpper());
+      AppendRangeIfNonEmpty(Range.getLower(), SubRange.getLower());
+      AppendRangeIfNonEmpty(SubRange.getUpper(), Range.getUpper());
     } else if (SubRange.getLower().sle(Range.getLower()) &&
                Range.getUpper().sle(SubRange.getUpper())) {
       // "SubRange" contains "Range".
@@ -128,16 +126,16 @@ void ConstantRangeList::subtract(const ConstantRange &SubRange) {
       // "Range" and "SubRange" overlap at the left.
       //       L---U        : Range
       //     L---U          : SubRange
-      AppendRange(SubRange.getUpper(), Range.getUpper());
+      AppendRangeIfNonEmpty(SubRange.getUpper(), Range.getUpper());
     } else {
       // "Range" and "SubRange" overlap at the right.
       //       L---U        : Range
       //         L---U      : SubRange
-      AppendRange(Range.getLower(), SubRange.getLower());
+      AppendRangeIfNonEmpty(Range.getLower(), SubRange.getLower());
     }
   }
 
-  Ranges.assign(Result.begin(), Result.end());
+  Ranges = Result;
 }
 
 ConstantRangeList
diff --git a/llvm/unittests/IR/ConstantRangeListTest.cpp b/llvm/unittests/IR/ConstantRangeListTest.cpp
index da3cb330871b3..d00e0a8ff2a97 100644
--- a/llvm/unittests/IR/ConstantRangeListTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeListTest.cpp
@@ -118,7 +118,8 @@ TEST_F(ConstantRangeListTest, Subtract) {
   auto SubtractAndCheck = [](ConstantRangeList CRL,
                              const std::pair<int64_t, int64_t> &Range,
                              const ConstantRangeList &ExpectedCRL) {
-    CRL.subtract(Range.first, Range.second);
+    CRL.subtract(ConstantRange(APInt(64, Range.first, /*isSigned=*/true),
+                               APInt(64, Range.second, /*isSigned=*/true)));
     EXPECT_EQ(CRL, ExpectedCRL);
   };
 
@@ -127,13 +128,20 @@ TEST_F(ConstantRangeListTest, Subtract) {
   SubtractAndCheck(CRL, {4, 8}, CRL);
   SubtractAndCheck(CRL, {12, 16}, CRL);
 
-  // Overlap (left or right)
+  // Overlap (left, right, or both)
   SubtractAndCheck(CRL, {-4, 2}, GetCRL({{AP2, AP4}, {AP8, AP12}}));
   SubtractAndCheck(CRL, {-4, 4}, GetCRL({{AP8, AP12}}));
   SubtractAndCheck(CRL, {-4, 8}, GetCRL({{AP8, AP12}}));
+  SubtractAndCheck(CRL, {0, 2}, GetCRL({{AP2, AP4}, {AP8, AP12}}));
+  SubtractAndCheck(CRL, {0, 4}, GetCRL({{AP8, AP12}}));
+  SubtractAndCheck(CRL, {0, 8}, GetCRL({{AP8, AP12}}));
+  SubtractAndCheck(CRL, {10, 12}, GetCRL({{AP0, AP4}, {AP8, AP10}}));
+  SubtractAndCheck(CRL, {8, 12}, GetCRL({{AP0, AP4}}));
+  SubtractAndCheck(CRL, {6, 12}, GetCRL({{AP0, AP4}}));
   SubtractAndCheck(CRL, {10, 16}, GetCRL({{AP0, AP4}, {AP8, AP10}}));
   SubtractAndCheck(CRL, {8, 16}, GetCRL({{AP0, AP4}}));
   SubtractAndCheck(CRL, {6, 16}, GetCRL({{AP0, AP4}}));
+  SubtractAndCheck(CRL, {2, 10}, GetCRL({{AP0, AP2}, {AP10, AP12}}));
 
   // Subset
   SubtractAndCheck(CRL, {2, 3}, GetCRL({{AP0, AP2}, {AP3, AP4}, {AP8, AP12}}));



More information about the llvm-commits mailing list