[llvm] [EquivalenceClasses] Introduce erase member function (PR #134660)

donald chen via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 7 07:40:01 PDT 2025


https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/134660

Introduce an 'erase(const ElemTy &V)' member function to allow the deletion of a given value from EquivClasses. This is essential for scenarios that require modifying the contents of EquivClasses.

This patch also incidentally fixes a problem of inaccurate leader setting when EquivClasses unions two classes. This problem only arises in the presence of the erase function.

>From b3f964bb3b38e9248f938e204c0090c986b29586 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 7 Apr 2025 14:29:57 +0000
Subject: [PATCH] [EquivalenceClasses] Introduce erase member function

Introduce an 'erase(const ElemTy &V)' member function to allow the
deletion of a given value from EquivClasses. This is essential for
scenarios that require modifying the contents of EquivClasses.

This patch also incidentally fixes a problem of inaccurate leader
setting when EquivClasses unions two classes. This problem only arises
in the presence of the erase function.
---
 llvm/include/llvm/ADT/EquivalenceClasses.h    | 42 ++++++++++++++--
 llvm/unittests/ADT/EquivalenceClassesTest.cpp | 49 +++++++++++++++++++
 2 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index e0a7af9421c35..e2e6b626993ef 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -220,6 +220,40 @@ template <class ElemTy> class EquivalenceClasses {
     return *ECV;
   }
 
+  /// erase - Erase a value from the union/find set, return if erase succeed.
+  bool erase(const ElemTy &V) {
+    if (!TheMapping.contains(V)) return false;
+    const ECValue *cur = TheMapping[V];
+    const ECValue *next = cur->getNext();
+    if (cur->isLeader()) {
+      if (next) {
+        next->Leader = cur->Leader;
+        auto nn = next->Next;
+        next->Next = (const ECValue*)((intptr_t)nn | (intptr_t)1);
+      }
+    } else {
+      const ECValue *leader = findLeader(V).Node;
+      const ECValue *pre = leader;
+      while (pre->getNext() != cur) {
+        pre = pre->getNext();
+      }
+      if (!next) {
+        pre->Next = nullptr;
+        leader->Leader = pre;
+      } else {
+        pre->Next = (const ECValue*)((intptr_t)next | (intptr_t)pre->isLeader());
+        next->Leader = pre;
+      }
+    }
+    TheMapping.erase(V);
+    for (auto I = Members.begin(); I != Members.end(); I++) {
+      if (*I == cur) {
+        Members.erase(I);
+        break;
+      }
+    }
+    return true;
+  }
   /// findLeader - Given a value in the set, return a member iterator for the
   /// equivalence class it is in.  This does the path-compression part that
   /// makes union-find "union findy".  This returns an end iterator if the value
@@ -247,7 +281,9 @@ template <class ElemTy> class EquivalenceClasses {
     // Otherwise, this is a real union operation.  Set the end of the L1 list to
     // point to the L2 leader node.
     const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
-    L1LV.getEndOfList()->setNext(&L2LV);
+    const ECValue *L1LastV = L1LV.getEndOfList();
+
+    L1LastV->setNext(&L2LV);
 
     // Update L1LV's end of list pointer.
     L1LV.Leader = L2LV.getEndOfList();
@@ -255,8 +291,8 @@ template <class ElemTy> class EquivalenceClasses {
     // Clear L2's leader flag:
     L2LV.Next = L2LV.getNext();
 
-    // L2's leader is now L1.
-    L2LV.Leader = &L1LV;
+    // L2's leader is now last value of L1.
+    L2LV.Leader = L1LastV;
     return L1;
   }
 
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index ff243f51102fb..3d5c48eb8e1b6 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -59,6 +59,55 @@ TEST(EquivalenceClassesTest, SimpleMerge2) {
       EXPECT_TRUE(EqClasses.isEquivalent(i, j));
 }
 
+TEST(EquivalenceClassesTest, SimpleErase1) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase head success.
+  // After erase A from (A, B ,C, D), <B, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(0));
+  for (int i = 1; i < 4; ++i)
+    for (int j = 1; j < 4; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase2) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase tail success.
+  // After erase D from (A, B ,C, D), <A, B, C> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(3));
+  for (int i = 0; i < 3; ++i)
+    for (int j = 0; j < 3; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase3) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase a value in the middle success.
+  // After erase B from (A, B ,C, D), <A, C, D> belong to one set.
+  EqClasses.unionSets(0, 1);
+  EqClasses.unionSets(2, 3);
+  EqClasses.unionSets(0, 2);
+  EXPECT_TRUE(EqClasses.erase(1));
+  for (int i = 0; i < 3; ++i)
+    for (int j = 0; j < 3; ++j)
+      EXPECT_TRUE(EqClasses.isEquivalent(i, j) ^ ((i == 1) ^ (j == 1)));
+}
+
+TEST(EquivalenceClassesTest, SimpleErase4) {
+  EquivalenceClasses<int> EqClasses;
+  // Check that erase a single class success.
+  EqClasses.insert(0);
+  EXPECT_TRUE(EqClasses.getNumClasses() == 1);
+  EXPECT_TRUE(EqClasses.erase(0));
+  EXPECT_TRUE(EqClasses.getNumClasses() == 0);
+  EXPECT_FALSE(EqClasses.erase(1));
+}
+
 TEST(EquivalenceClassesTest, TwoSets) {
   EquivalenceClasses<int> EqClasses;
   // Form sets of odd and even numbers, check that we split them into these



More information about the llvm-commits mailing list