[llvm] [EquivalenceClasses] Introduce erase member function (PR #134660)

donald chen via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 7 08:20:20 PDT 2025


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/134660

>From b3f964bb3b38e9248f938e204c0090c986b29586 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 7 Apr 2025 14:29:57 +0000
Subject: [PATCH 1/2] [EquivalenceClasses] Introduce erase member function

Introduce an 'erase(const ElemTy &V)' member function to allow the
deletion of a single value from an EquivalenceClasses instance. This is
needed in scenarios that must modify the contents of an existing
EquivalenceClasses.

This patch also incidentally fixes an inaccurate leader assignment when
EquivalenceClasses unions two classes. This problem only arises in the
presence of the erase function.
---
 llvm/include/llvm/ADT/EquivalenceClasses.h    | 42 ++++++++++++++--
 llvm/unittests/ADT/EquivalenceClassesTest.cpp | 49 +++++++++++++++++++
 2 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index e0a7af9421c35..e2e6b626993ef 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -220,6 +220,40 @@ template <class ElemTy> class EquivalenceClasses {
     return *ECV;
   }
 
+  /// erase - Erase a value from the union/find set, return if erase succeed.
+  bool erase(const ElemTy &V) {
+    if (!TheMapping.contains(V)) return false;
+    const ECValue *cur = TheMapping[V];
+    const ECValue *next = cur->getNext();
+    if (cur->isLeader()) {
+      if (next) {
+        next->Leader = cur->Leader;
+        auto nn = next->Next;
+        next->Next = (const ECValue*)((intptr_t)nn | (intptr_t)1);
+      }
+    } else {
+      const ECValue *leader = findLeader(V).Node;
+      const ECValue *pre = leader;
+      while (pre->getNext() != cur) {
+        pre = pre->getNext();
+      }
+      if (!next) {
+        pre->Next = nullptr;
+        leader->Leader = pre;
+      } else {
+        pre->Next = (const ECValue*)((intptr_t)next | (intptr_t)pre->isLeader());
+        next->Leader = pre;
+      }
+    }
+    TheMapping.erase(V);
+    for (auto I = Members.begin(); I != Members.end(); I++) {
+      if (*I == cur) {
+        Members.erase(I);
+        break;
+      }
+    }
+    return true;
+  }
   /// findLeader - Given a value in the set, return a member iterator for the
   /// equivalence class it is in.  This does the path-compression part that
   /// makes union-find "union findy".  This returns an end iterator if the value
@@ -247,7 +281,9 @@ template <class ElemTy> class EquivalenceClasses {
     // Otherwise, this is a real union operation.  Set the end of the L1 list to
     // point to the L2 leader node.
     const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
-    L1LV.getEndOfList()->setNext(&L2LV);
+    const ECValue *L1LastV = L1LV.getEndOfList();
+
+    L1LastV->setNext(&L2LV);
 
     // Update L1LV's end of list pointer.
     L1LV.Leader = L2LV.getEndOfList();
@@ -255,8 +291,8 @@ template <class ElemTy> class EquivalenceClasses {
     // Clear L2's leader flag:
     L2LV.Next = L2LV.getNext();
 
-    // L2's leader is now L1.
-    L2LV.Leader = &L1LV;
+    // L2's leader is now last value of L1.
+    L2LV.Leader = L1LastV;
     return L1;
   }
 
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index ff243f51102fb..3d5c48eb8e1b6 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -59,6 +59,55 @@ TEST(EquivalenceClassesTest, SimpleMerge2) {
       EXPECT_TRUE(EqClasses.isEquivalent(i, j));
 }
 
+TEST(EquivalenceClassesTest, SimpleErase1) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase head success.
+  // After erase A from (A, B ,C, D), <B, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(0));
+  for (int i = 1; i < 4; ++i)
+    for (int j = 1; j < 4; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase2) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase tail success.
+  // After erase D from (A, B ,C, D), <A, B, C> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(3));
+  for (int i = 0; i < 3; ++i)
+    for (int j = 0; j < 3; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase3) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase a value in the middle success.
+  // After erase B from (A, B ,C, D), <A, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(1));
+  for (int i = 0; i < 3; ++i)
+    for (int j = 0; j < 3; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j) ^ ((i == 1) ^ (j == 1)));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase4) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase a single class success.
+  EqClasses.insert(0);
+  EXPECT_TRUE(EqClasses.getNumClasses() == 1);
+  EXPECT_TRUE(EqClasses.erase(0));
+  EXPECT_TRUE(EqClasses.getNumClasses() == 0);
+  EXPECT_FALSE(EqClasses.erase(1));
+}
+
 TEST(EquivalenceClassesTest, TwoSets) {
   EquivalenceClasses<int> EqClasses;
   // Form sets of odd and even numbers, check that we split them into these

>From 8fcec6df69ec642ee809dda8eec597c7348ac4ab Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 7 Apr 2025 23:19:50 +0800
Subject: [PATCH 2/2] fix coding style

---
 llvm/include/llvm/ADT/EquivalenceClasses.h    | 38 +++++++++----------
 llvm/unittests/ADT/EquivalenceClassesTest.cpp |  2 +-
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index e2e6b626993ef..13a93089dbaf3 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -78,11 +78,13 @@ template <class ElemTy> class EquivalenceClasses {
     // ECValue ctor - Start out with EndOfList pointing to this node, Next is
     // Null, isLeader = true.
     ECValue(const ElemTy &Elt)
-      : Leader(this), Next((ECValue*)(intptr_t)1), Data(Elt) {}
+        : Leader(this), Next((ECValue *)(intptr_t)1), Data(Elt) {}
 
     const ECValue *getLeader() const {
-      if (isLeader()) return this;
-      if (Leader->isLeader()) return Leader;
+      if (isLeader())
+        return this;
+      if (Leader->isLeader())
+        return Leader;
       // Path compression.
       return Leader = Leader->getLeader();
     }
@@ -94,12 +96,12 @@ template <class ElemTy> class EquivalenceClasses {
 
     void setNext(const ECValue *NewNext) const {
       assert(getNext() == nullptr && "Already has a next pointer!");
-      Next = (const ECValue*)((intptr_t)NewNext | (intptr_t)isLeader());
+      Next = (const ECValue *)((intptr_t)NewNext | (intptr_t)isLeader());
     }
 
   public:
-    ECValue(const ECValue &RHS) : Leader(this), Next((ECValue*)(intptr_t)1),
-                                  Data(RHS.Data) {
+    ECValue(const ECValue &RHS)
+        : Leader(this), Next((ECValue *)(intptr_t)1), Data(RHS.Data) {
       // Only support copying of singleton nodes.
       assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
     }
@@ -108,7 +110,7 @@ template <class ElemTy> class EquivalenceClasses {
     const ElemTy &getData() const { return Data; }
 
     const ECValue *getNext() const {
-      return (ECValue*)((intptr_t)Next & ~(intptr_t)1);
+      return (ECValue *)((intptr_t)Next & ~(intptr_t)1);
     }
   };
 
@@ -123,9 +125,7 @@ template <class ElemTy> class EquivalenceClasses {
 
 public:
   EquivalenceClasses() = default;
-  EquivalenceClasses(const EquivalenceClasses &RHS) {
-    operator=(RHS);
-  }
+  EquivalenceClasses(const EquivalenceClasses &RHS) { operator=(RHS); }
 
   EquivalenceClasses &operator=(const EquivalenceClasses &RHS) {
     TheMapping.clear();
@@ -159,9 +159,7 @@ template <class ElemTy> class EquivalenceClasses {
     return member_iterator(ECV.isLeader() ? &ECV : nullptr);
   }
 
-  member_iterator member_end() const {
-    return member_iterator(nullptr);
-  }
+  member_iterator member_end() const { return member_iterator(nullptr); }
 
   iterator_range<member_iterator> members(const ECValue &ECV) const {
     return make_range(member_begin(ECV), member_end());
@@ -222,14 +220,14 @@ template <class ElemTy> class EquivalenceClasses {
 
   /// erase - Erase a value from the union/find set, return if erase succeed.
   bool erase(const ElemTy &V) {
-    if (!TheMapping.contains(V)) return false;
+    if (!TheMapping.contains(V))
+      return false;
     const ECValue *cur = TheMapping[V];
     const ECValue *next = cur->getNext();
     if (cur->isLeader()) {
       if (next) {
         next->Leader = cur->Leader;
-        auto nn = next->Next;
-        next->Next = (const ECValue*)((intptr_t)nn | (intptr_t)1);
+        next->Next = (const ECValue *)((intptr_t)next->Next | (intptr_t)1);
       }
     } else {
       const ECValue *leader = findLeader(V).Node;
@@ -241,7 +239,8 @@ template <class ElemTy> class EquivalenceClasses {
         pre->Next = nullptr;
         leader->Leader = pre;
       } else {
-        pre->Next = (const ECValue*)((intptr_t)next | (intptr_t)pre->isLeader());
+        pre->Next =
+            (const ECValue *)((intptr_t)next | (intptr_t)pre->isLeader());
         next->Leader = pre;
       }
     }
@@ -276,7 +275,8 @@ template <class ElemTy> class EquivalenceClasses {
   }
   member_iterator unionSets(member_iterator L1, member_iterator L2) {
     assert(L1 != member_end() && L2 != member_end() && "Illegal inputs!");
-    if (L1 == L2) return L1;   // Unifying the same two sets, noop.
+    if (L1 == L2)
+      return L1; // Unifying the same two sets, noop.
 
     // Otherwise, this is a real union operation.  Set the end of the L1 list to
     // point to the L2 leader node.
@@ -334,7 +334,7 @@ template <class ElemTy> class EquivalenceClasses {
       return *this;
     }
 
-    member_iterator operator++(int) {    // postincrement operators.
+    member_iterator operator++(int) { // postincrement operators.
       member_iterator tmp = *this;
       ++*this;
       return tmp;
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index 3d5c48eb8e1b6..1467f46fbf176 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -158,4 +158,4 @@ TYPED_TEST_P(ParameterizedTest, MultipleSets) {
         EXPECT_FALSE(EqClasses.isEquivalent(i, j));
 }
 
-} // llvm
+} // namespace llvm



More information about the llvm-commits mailing list