[llvm] [EquivalenceClasses] Introduce erase member function (PR #134660)

donald chen via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 9 05:52:57 PDT 2025


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/134660

>From b3f964bb3b38e9248f938e204c0090c986b29586 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 7 Apr 2025 14:29:57 +0000
Subject: [PATCH 1/3] [EquivalenceClasses] Introduce erase member function

Introduce an 'erase(const ElemTy &V)' member function to allow
deleting a given value from EquivalenceClasses. This is essential for
certain scenarios that require modifying the contents of EquivalenceClasses.

This patch also incidentally fixes a problem of inaccurate leader
setting when EquivalenceClasses unions two classes. This problem only
arises in the presence of the erase function.
---
 llvm/include/llvm/ADT/EquivalenceClasses.h    | 42 ++++++++++++++--
 llvm/unittests/ADT/EquivalenceClassesTest.cpp | 49 +++++++++++++++++++
 2 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index e0a7af9421c35..e2e6b626993ef 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -220,6 +220,40 @@ template <class ElemTy> class EquivalenceClasses {
     return *ECV;
   }
 
+  /// erase - Erase a value from the union/find set, return if erase succeed.
+  bool erase(const ElemTy &V) {
+    if (!TheMapping.contains(V)) return false;
+    const ECValue *cur = TheMapping[V];
+    const ECValue *next = cur->getNext();
+    if (cur->isLeader()) {
+      if (next) {
+        next->Leader = cur->Leader;
+        auto nn = next->Next;
+        next->Next = (const ECValue*)((intptr_t)nn | (intptr_t)1);
+      }
+    } else {
+      const ECValue *leader = findLeader(V).Node;
+      const ECValue *pre = leader;
+      while (pre->getNext() != cur) {
+        pre = pre->getNext();
+      }
+      if (!next) {
+        pre->Next = nullptr;
+        leader->Leader = pre;
+      } else {
+        pre->Next = (const ECValue*)((intptr_t)next | (intptr_t)pre->isLeader());
+        next->Leader = pre;
+      }
+    }
+    TheMapping.erase(V);
+    for (auto I = Members.begin(); I != Members.end(); I++) {
+      if (*I == cur) {
+        Members.erase(I);
+        break;
+      }
+    }
+    return true;
+  }
   /// findLeader - Given a value in the set, return a member iterator for the
   /// equivalence class it is in.  This does the path-compression part that
   /// makes union-find "union findy".  This returns an end iterator if the value
@@ -247,7 +281,9 @@ template <class ElemTy> class EquivalenceClasses {
     // Otherwise, this is a real union operation.  Set the end of the L1 list to
     // point to the L2 leader node.
     const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
-    L1LV.getEndOfList()->setNext(&L2LV);
+    const ECValue *L1LastV = L1LV.getEndOfList();
+
+    L1LastV->setNext(&L2LV);
 
     // Update L1LV's end of list pointer.
     L1LV.Leader = L2LV.getEndOfList();
@@ -255,8 +291,8 @@ template <class ElemTy> class EquivalenceClasses {
     // Clear L2's leader flag:
     L2LV.Next = L2LV.getNext();
 
-    // L2's leader is now L1.
-    L2LV.Leader = &L1LV;
+    // L2's leader is now last value of L1.
+    L2LV.Leader = L1LastV;
     return L1;
   }
 
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index ff243f51102fb..3d5c48eb8e1b6 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -59,6 +59,55 @@ TEST(EquivalenceClassesTest, SimpleMerge2) {
       EXPECT_TRUE(EqClasses.isEquivalent(i, j));
 }
 
+TEST(EquivalenceClassesTest, SimpleErase1) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erasing the head (the leader) of a class succeeds.
+  // After erasing A from (A, B, C, D), <B, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(0));
+  for (int i = 1; i < 4; ++i)
+    for (int j = 1; j < 4; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase2) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erasing the tail of a class succeeds.
+  // After erasing D from (A, B, C, D), <A, B, C> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(3));
+  for (int i = 0; i < 3; ++i)
+    for (int j = 0; j < 3; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase3) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erasing a value in the middle of a class succeeds.
+  // After erasing B from (A, B, C, D), <A, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(1));
+  for (int i = 0; i < 4; ++i)
+    for (int j = 0; j < 4; ++j)
+      EXPECT_EQ(EqClasses.isEquivalent(i, j), (i == 1) == (j == 1));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase4) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erasing a singleton class succeeds and a missing value fails.
+  EqClasses.insert(0);
+  EXPECT_TRUE(EqClasses.getNumClasses() == 1);
+  EXPECT_TRUE(EqClasses.erase(0));
+  EXPECT_TRUE(EqClasses.getNumClasses() == 0);
+  EXPECT_FALSE(EqClasses.erase(1));
+}
+
 TEST(EquivalenceClassesTest, TwoSets) {
   EquivalenceClasses<int> EqClasses;
   // Form sets of odd and even numbers, check that we split them into these

>From acf94ab53fd6c9004b51d8cb50a38f43ee7a9a42 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Tue, 8 Apr 2025 05:16:20 +0000
Subject: [PATCH 2/3] fix comment and style ...

---
 llvm/include/llvm/ADT/EquivalenceClasses.h | 50 ++++++++++++----------
 1 file changed, 27 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index e2e6b626993ef..ff3f74a29e45d 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -222,38 +222,44 @@ template <class ElemTy> class EquivalenceClasses {
 
   /// erase - Erase a value from the union/find set, return if erase succeed.
   bool erase(const ElemTy &V) {
-    if (!TheMapping.contains(V)) return false;
-    const ECValue *cur = TheMapping[V];
-    const ECValue *next = cur->getNext();
-    if (cur->isLeader()) {
-      if (next) {
-        next->Leader = cur->Leader;
-        auto nn = next->Next;
-        next->Next = (const ECValue*)((intptr_t)nn | (intptr_t)1);
+    if (!TheMapping.contains(V))
+      return false;
+    const ECValue *Cur = TheMapping[V];
+    const ECValue *Next = Cur->getNext();
+    if (Cur->isLeader()) {
+      if (Next) {
+        Next->Leader = Cur->Leader;
+        Next->Next = (const ECValue *)((intptr_t)Next->Next | (intptr_t)1);
+        const ECValue *newLeader = Next;
+        while ((Next = Next->getNext())) {
+          Next->Leader = newLeader;
+        }
       }
     } else {
-      const ECValue *leader = findLeader(V).Node;
-      const ECValue *pre = leader;
-      while (pre->getNext() != cur) {
-        pre = pre->getNext();
+      const ECValue *Leader = findLeader(V).Node;
+      const ECValue *Pre = Leader;
+      while (Pre->getNext() != Cur) {
+        Pre = Pre->getNext();
       }
-      if (!next) {
-        pre->Next = nullptr;
-        leader->Leader = pre;
+      if (!Next) {
+        Pre->Next = nullptr;
+        Leader->Leader = Pre;
       } else {
-        pre->Next = (const ECValue*)((intptr_t)next | (intptr_t)pre->isLeader());
-        next->Leader = pre;
+        Pre->Next =
+            (const ECValue *)((intptr_t)Next | (intptr_t)Pre->isLeader());
+        Next->Leader = Pre;
       }
     }
     TheMapping.erase(V);
     for (auto I = Members.begin(); I != Members.end(); I++) {
-      if (*I == cur) {
+      if (*I == Cur) {
         Members.erase(I);
         break;
       }
     }
     return true;
   }
+
   /// findLeader - Given a value in the set, return a member iterator for the
   /// equivalence class it is in.  This does the path-compression part that
   /// makes union-find "union findy".  This returns an end iterator if the value
@@ -281,9 +287,7 @@ template <class ElemTy> class EquivalenceClasses {
     // Otherwise, this is a real union operation.  Set the end of the L1 list to
     // point to the L2 leader node.
     const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
-    const ECValue *L1LastV = L1LV.getEndOfList();
-
-    L1LastV->setNext(&L2LV);
+    L1LV.getEndOfList()->setNext(&L2LV);
 
     // Update L1LV's end of list pointer.
     L1LV.Leader = L2LV.getEndOfList();
@@ -291,8 +295,8 @@ template <class ElemTy> class EquivalenceClasses {
     // Clear L2's leader flag:
     L2LV.Next = L2LV.getNext();
 
-    // L2's leader is now last value of L1.
-    L2LV.Leader = L1LastV;
+    // L2's leader is now L1.
+    L2LV.Leader = &L1LV;
     return L1;
   }
 

>From 2c62332e186736d05cd603af4bfbdfd6ce9d73ed Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Wed, 9 Apr 2025 20:51:16 +0800
Subject: [PATCH 3/3] Update the code based on comments

---
 llvm/include/llvm/ADT/EquivalenceClasses.h    | 29 ++++++++++++++-----
 llvm/unittests/ADT/EquivalenceClassesTest.cpp | 11 +++++++
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index ff3f74a29e45d..134effc032042 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -17,6 +17,7 @@

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Allocator.h"
 #include <cassert>
@@ -220,13 +221,18 @@ template <class ElemTy> class EquivalenceClasses {
     return *ECV;
   }

-  /// erase - Erase a value from the union/find set, return if erase succeed.
+  /// erase - Erase a value from the union/find set, return "true" if erase
+  /// succeeded.
   bool erase(const ElemTy &V) {
     if (!TheMapping.contains(V))
       return false;
     const ECValue *Cur = TheMapping[V];
     const ECValue *Next = Cur->getNext();
     if (Cur->isLeader()) {
+      // If the current element is the leader and has a successor element,
+      // update the successor element's 'Leader' field to be the last element,
+      // set the successor element's leader bit, and set the 'Leader' field of
+      // all other elements in the same class to be the successor element.
       if (Next) {
         Next->Leader = Cur->Leader;
         Next->Next = (const ECValue *)((intptr_t)Next->Next | (intptr_t)1);
@@ -242,21 +248,28 @@ template <class ElemTy> class EquivalenceClasses {
         Pre = Pre->getNext();
       }
       if (!Next) {
-        Pre->Next = nullptr;
+        // If the current element is the last element (not the leader), drop
+        // it from the predecessor's 'Next' while preserving the predecessor's
+        // leader bit (the low bit of 'Next'; the predecessor may be the leader
+        // itself), and make the predecessor the new end of the list.
+        Pre->Next = reinterpret_cast<const ECValue *>(
+            static_cast<intptr_t>(Pre->isLeader()));
         Leader->Leader = Pre;
       } else {
+        // If the current element is in the middle of the class, simply
+        // connect the predecessor element and the successor element.
         Pre->Next =
             (const ECValue *)((intptr_t)Next | (intptr_t)Pre->isLeader());
         Next->Leader = Pre;
       }
     }
+
+    // Update 'TheMapping' and 'Members'.
+    assert(TheMapping.contains(V) && "Can't find input in TheMapping!");
     TheMapping.erase(V);
-    for (auto I = Members.begin(); I != Members.end(); I++) {
-      if (*I == Cur) {
-        Members.erase(I);
-        break;
-      }
-    }
+    auto I = llvm::find(Members, Cur);
+    assert(I != Members.end() && "Can't find input in members!");
+    Members.erase(I);
     return true;
   }

diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -107,6 +107,17 @@ TEST(EquivalenceClassesTest, SimpleErase4) {
   EXPECT_TRUE(EqClasses.getNumClasses() == 0);
   EXPECT_FALSE(EqClasses.erase(1));
 }
+
+TEST(EquivalenceClassesTest, SimpleErase5) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erasing the tail of a two-element class leaves the remaining
+  // element as a valid leader of its own class.
+  EqClasses.unionSets(0, 1);
+  EXPECT_TRUE(EqClasses.erase(1));
+  EXPECT_TRUE(EqClasses.getNumClasses() == 1);
+  EXPECT_EQ(EqClasses.getLeaderValue(0), 0);
+  EXPECT_FALSE(EqClasses.isEquivalent(0, 1));
+}

 TEST(EquivalenceClassesTest, TwoSets) {
   EquivalenceClasses<int> EqClasses;



More information about the llvm-commits mailing list