[llvm] e6c2216 - Add ConstantRangeList::unionWith() and ::intersectWith() (#96547)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 25 16:38:07 PDT 2024
Author: Haopeng Liu
Date: 2024-06-25T16:38:03-07:00
New Revision: e6c2216940885ae5b0d6509dac73417c40d6c62f
URL: https://github.com/llvm/llvm-project/commit/e6c2216940885ae5b0d6509dac73417c40d6c62f
DIFF: https://github.com/llvm/llvm-project/commit/e6c2216940885ae5b0d6509dac73417c40d6c62f.diff
LOG: Add ConstantRangeList::unionWith() and ::intersectWith() (#96547)
Add ConstantRangeList::unionWith() and ::intersectWith().
These methods will be used in the "initializes" attribute inference.
https://github.com/llvm/llvm-project/commit/df11106068294fb00f11988d3f48336e2cbed364
Added:
Modified:
llvm/include/llvm/IR/ConstantRangeList.h
llvm/lib/IR/ConstantRangeList.cpp
llvm/unittests/IR/ConstantRangeListTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/ConstantRangeList.h b/llvm/include/llvm/IR/ConstantRangeList.h
index f696bd6cc6a3d..46edaff19e73f 100644
--- a/llvm/include/llvm/IR/ConstantRangeList.h
+++ b/llvm/include/llvm/IR/ConstantRangeList.h
@@ -72,6 +72,14 @@ class [[nodiscard]] ConstantRangeList {
APInt(64, Upper, /*isSigned=*/true)));
}
+ /// Return the range list that results from the union of this
+ /// ConstantRangeList with another ConstantRangeList, "CRL".
+ /// Both lists must have the same bit width (this is asserted). Ranges from
+ /// the two lists that overlap or merely touch (one's upper bound equals the
+ /// other's lower bound) are merged into a single range. The inputs are
+ /// walked in order of lower bound, so the result is ordered the same way.
+ ConstantRangeList unionWith(const ConstantRangeList &CRL) const;
+
+ /// Return the range list that results from the intersection of this
+ /// ConstantRangeList with another ConstantRangeList, "CRL".
+ /// Both lists must have the same bit width (this is asserted). Every pair
+ /// of ranges that overlaps contributes the range from the signed maximum of
+ /// their lower bounds to the signed minimum of their upper bounds. Ranges
+ /// that only touch (one ends where the other begins) contribute nothing.
+ ConstantRangeList intersectWith(const ConstantRangeList &CRL) const;
+
/// Return true if this range list is equal to another range list.
bool operator==(const ConstantRangeList &CRL) const {
return Ranges == CRL.Ranges;
diff --git a/llvm/lib/IR/ConstantRangeList.cpp b/llvm/lib/IR/ConstantRangeList.cpp
index 2cc483d4e4962..0373524a09f10 100644
--- a/llvm/lib/IR/ConstantRangeList.cpp
+++ b/llvm/lib/IR/ConstantRangeList.cpp
@@ -81,6 +81,95 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) {
}
}
+ConstantRangeList
+ConstantRangeList::unionWith(const ConstantRangeList &CRL) const {
+  assert(getBitWidth() == CRL.getBitWidth() &&
+         "ConstantRangeList bitwidths don't agree!");
+  // The union with an empty list is simply the other list.
+  if (empty())
+    return CRL;
+  if (CRL.empty())
+    return *this;
+
+  const size_t LHSSize = size();
+  const size_t RHSSize = CRL.size();
+  size_t LHSIdx = 0;
+  size_t RHSIdx = 0;
+
+  // Yield the pending range with the smaller lower bound from either list
+  // (a two-way merge). On equal lower bounds the range from CRL is taken
+  // first. Only called while at least one list still has ranges left.
+  auto TakeNext = [&]() -> const ConstantRange & {
+    bool FromLHS = RHSIdx == RHSSize;
+    if (!FromLHS && LHSIdx < LHSSize)
+      FromLHS = Ranges[LHSIdx].getLower().slt(CRL.Ranges[RHSIdx].getLower());
+    return FromLHS ? Ranges[LHSIdx++] : CRL.Ranges[RHSIdx++];
+  };
+
+  // Current is the merged range under construction. Its lower bound is
+  // fixed once chosen, and its upper bound grows as later ranges overlap or
+  // touch it. Ranges arrive in order of lower bound, so Current's lower bound
+  // is never greater than that of the next range.
+  ConstantRangeList Result;
+  ConstantRange Current = TakeNext();
+  while (LHSIdx < LHSSize || RHSIdx < RHSSize) {
+    const ConstantRange &Next = TakeNext();
+    if (Current.getUpper().slt(Next.getLower())) {
+      // A gap separates the two, so Current is final.
+      Result.Ranges.push_back(Current);
+      Current = Next;
+      continue;
+    }
+    // Overlapping or adjacent: stretch Current to cover Next as well.
+    APInt MergedUpper = APIntOps::smax(Current.getUpper(), Next.getUpper());
+    Current = ConstantRange(Current.getLower(), MergedUpper);
+  }
+  Result.Ranges.push_back(Current);
+  return Result;
+}
+
+ConstantRangeList
+ConstantRangeList::intersectWith(const ConstantRangeList &CRL) const {
+  assert(getBitWidth() == CRL.getBitWidth() &&
+         "ConstantRangeList bitwidths don't agree!");
+
+  // Intersecting with an empty list always gives an empty list.
+  if (empty())
+    return *this;
+  if (CRL.empty())
+    return CRL;
+
+  ConstantRangeList Result;
+  size_t LHSIdx = 0;
+  size_t RHSIdx = 0;
+  while (LHSIdx < size() && RHSIdx < CRL.size()) {
+    const ConstantRange &LHS = Ranges[LHSIdx];
+    const ConstantRange &RHS = CRL.Ranges[RHSIdx];
+
+    // Comparing bounds as signed values, the overlap of LHS and RHS runs from
+    // smax(lowers) to smin(uppers). It is empty unless the new lower bound is
+    // strictly below the new upper bound. ConstantRange::intersectWith() is
+    // intentionally avoided: it also models wrapped ranges and may need two
+    // pieces to express the answer, e.g. (2, 8) && (6, 4) = {(2, 4), (6, 8)}.
+    APInt Lo = APIntOps::smax(LHS.getLower(), RHS.getLower());
+    APInt Hi = APIntOps::smin(LHS.getUpper(), RHS.getUpper());
+    if (Lo.slt(Hi))
+      Result.Ranges.push_back(ConstantRange(Lo, Hi));
+
+    // Step past the range that ends first (CRL's on a tie). The other range
+    // may still overlap later ranges. For A = {(0, 2), (4, 8)} and
+    // B = {(-2, 5), (6, 10)} this visits A0 && B0, A1 && B0, A1 && B1.
+    if (LHS.getUpper().slt(RHS.getUpper()))
+      ++LHSIdx;
+    else
+      ++RHSIdx;
+  }
+  return Result;
+}
+
void ConstantRangeList::print(raw_ostream &OS) const {
interleaveComma(Ranges, OS, [&](ConstantRange CR) {
OS << "(" << CR.getLower() << ", " << CR.getUpper() << ")";
diff --git a/llvm/unittests/IR/ConstantRangeListTest.cpp b/llvm/unittests/IR/ConstantRangeListTest.cpp
index 144b5ccdc1fc0..b679dd3a33d5d 100644
--- a/llvm/unittests/IR/ConstantRangeListTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeListTest.cpp
@@ -94,4 +94,116 @@ TEST_F(ConstantRangeListTest, Insert) {
EXPECT_TRUE(CRL == Expected);
}
+/// Build a ConstantRangeList whose ranges are made from the given
+/// (Lower, Upper) pairs, in the order supplied.
+ConstantRangeList GetCRL(ArrayRef<std::pair<APInt, APInt>> Pairs) {
+  SmallVector<ConstantRange, 2> Ranges;
+  for (const std::pair<APInt, APInt> &Bounds : Pairs)
+    Ranges.push_back(ConstantRange(Bounds.first, Bounds.second));
+  return ConstantRangeList(Ranges);
+}
+
+// Exercises ConstantRangeList::unionWith() against CRL = {(0, 4), (8, 12)}.
+TEST_F(ConstantRangeListTest, Union) {
+  APInt APN4 = APInt(64, -4, /*isSigned=*/true);
+  APInt APN2 = APInt(64, -2, /*isSigned=*/true);
+  APInt AP0 = APInt(64, 0, /*isSigned=*/true);
+  APInt AP2 = APInt(64, 2, /*isSigned=*/true);
+  APInt AP4 = APInt(64, 4, /*isSigned=*/true);
+  APInt AP6 = APInt(64, 6, /*isSigned=*/true);
+  APInt AP7 = APInt(64, 7, /*isSigned=*/true);
+  APInt AP8 = APInt(64, 8, /*isSigned=*/true);
+  APInt AP10 = APInt(64, 10, /*isSigned=*/true);
+  APInt AP11 = APInt(64, 11, /*isSigned=*/true);
+  APInt AP12 = APInt(64, 12, /*isSigned=*/true);
+  APInt AP16 = APInt(64, 16, /*isSigned=*/true);
+  APInt AP18 = APInt(64, 18, /*isSigned=*/true);
+  ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});
+
+  // Union with a subset.
+  ConstantRangeList Empty;
+  EXPECT_EQ(CRL.unionWith(Empty), CRL);
+  EXPECT_EQ(Empty.unionWith(CRL), CRL);
+
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}})), CRL);
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP10, AP12}})), CRL);
+
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}, {AP8, AP10}})), CRL);
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}, {AP10, AP12}})), CRL);
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP4}, {AP8, AP10}})), CRL);
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP4}, {AP10, AP12}})), CRL);
+
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP4}, {AP8, AP10}, {AP11, AP12}})),
+            CRL);
+
+  EXPECT_EQ(CRL.unionWith(CRL), CRL);
+
+  // Union with new ranges.
+  EXPECT_EQ(CRL.unionWith(GetCRL({{APN4, APN2}})),
+            GetCRL({{APN4, APN2}, {AP0, AP4}, {AP8, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP6, AP7}})),
+            GetCRL({{AP0, AP4}, {AP6, AP7}, {AP8, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP16, AP18}})),
+            GetCRL({{AP0, AP4}, {AP8, AP12}, {AP16, AP18}}));
+
+  EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP2}})),
+            GetCRL({{APN2, AP4}, {AP8, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP6}})),
+            GetCRL({{AP0, AP6}, {AP8, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP10, AP16}})),
+            GetCRL({{AP0, AP4}, {AP8, AP16}}));
+
+  EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP10}})), GetCRL({{APN2, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP10}})), GetCRL({{AP0, AP12}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP4, AP16}})), GetCRL({{AP0, AP16}}));
+  EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP16}})), GetCRL({{APN2, AP16}}));
+
+  // A range that exactly fills the gap touches both neighbours, and ranges
+  // that touch are merged, so everything collapses into one range.
+  EXPECT_EQ(CRL.unionWith(GetCRL({{AP4, AP8}})), GetCRL({{AP0, AP12}}));
+
+  // Swapping the operands must give the same result.
+  EXPECT_EQ(GetCRL({{AP2, AP10}}).unionWith(CRL), GetCRL({{AP0, AP12}}));
+  EXPECT_EQ(GetCRL({{APN4, APN2}}).unionWith(CRL),
+            GetCRL({{APN4, APN2}, {AP0, AP4}, {AP8, AP12}}));
+}
+
+// Exercises ConstantRangeList::intersectWith() against
+// CRL = {(0, 4), (8, 12)}.
+TEST_F(ConstantRangeListTest, Intersect) {
+  APInt APN2 = APInt(64, -2, /*isSigned=*/true);
+  APInt AP0 = APInt(64, 0, /*isSigned=*/true);
+  APInt AP2 = APInt(64, 2, /*isSigned=*/true);
+  APInt AP4 = APInt(64, 4, /*isSigned=*/true);
+  APInt AP6 = APInt(64, 6, /*isSigned=*/true);
+  APInt AP7 = APInt(64, 7, /*isSigned=*/true);
+  APInt AP8 = APInt(64, 8, /*isSigned=*/true);
+  APInt AP10 = APInt(64, 10, /*isSigned=*/true);
+  APInt AP11 = APInt(64, 11, /*isSigned=*/true);
+  APInt AP12 = APInt(64, 12, /*isSigned=*/true);
+  APInt AP16 = APInt(64, 16, /*isSigned=*/true);
+  ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});
+
+  // No intersection.
+  ConstantRangeList Empty;
+  EXPECT_EQ(CRL.intersectWith(Empty), Empty);
+  EXPECT_EQ(Empty.intersectWith(CRL), Empty);
+
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP0}})), Empty);
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP8}})), Empty);
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP12, AP16}})), Empty);
+
+  // A range that touches CRL's ranges at both ends shares no element with
+  // either of them.
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP4, AP8}})), Empty);
+
+  // Single intersect range.
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}})), GetCRL({{AP0, AP2}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP6}})), GetCRL({{AP0, AP4}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP4}})), GetCRL({{AP2, AP4}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP6}})), GetCRL({{AP2, AP4}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP10}})), GetCRL({{AP8, AP10}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP16}})), GetCRL({{AP8, AP12}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP10, AP12}})), GetCRL({{AP10, AP12}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP10, AP16}})), GetCRL({{AP10, AP12}}));
+
+  // Multiple intersect ranges.
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP10}})),
+            GetCRL({{AP0, AP4}, {AP8, AP10}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP16}})), CRL);
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP10}})),
+            GetCRL({{AP2, AP4}, {AP8, AP10}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP16}})),
+            GetCRL({{AP2, AP4}, {AP8, AP12}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}, {AP6, AP10}})),
+            GetCRL({{AP0, AP2}, {AP8, AP10}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP6}, {AP10, AP16}})),
+            GetCRL({{AP2, AP4}, {AP10, AP12}}));
+  EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}, {AP7, AP10}, {AP11, AP16}})),
+            GetCRL({{AP0, AP2}, {AP8, AP10}, {AP11, AP12}}));
+  EXPECT_EQ(CRL.intersectWith(CRL), CRL);
+
+  // Swapping the operands must give the same result.
+  EXPECT_EQ(GetCRL({{APN2, AP10}}).intersectWith(CRL),
+            GetCRL({{AP0, AP4}, {AP8, AP10}}));
+  EXPECT_EQ(GetCRL({{AP2, AP6}, {AP10, AP16}}).intersectWith(CRL),
+            GetCRL({{AP2, AP4}, {AP10, AP12}}));
+}
+
} // anonymous namespace
More information about the llvm-commits mailing list