[clang] [llvm] [ImmutableSet] Optimize add/remove operations to avoid redundant tree modifications (PR #159845)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 19 13:40:02 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-temporal-safety
Author: Utkarsh Saxena (usx95)
<details>
<summary>Changes</summary>
Optimize ImmutableSet operations to avoid unnecessary tree modifications when adding an element that is already present or removing an element that is absent.
- Modified `ImutAVLFactory::add_internal()` to return the original tree when both the key and the value are unchanged, avoiding unnecessary node creation.
- Updated `ImutAVLFactory::add_internal()` and `remove_internal()` to return the original (sub)tree when the recursive call leaves both child subtrees unchanged.
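Note that `balanceTree` always ends up creating at least one new node, even when no rebalancing is done, so unnecessary calls to it must also be avoided. Because `add` now returns the original tree when nothing changes, the redundant pre-checks in the `join` helpers of `clang/lib/Analysis/LifetimeSafety.cpp` (`A == B`, `contains`, and the `*ValA != ValB` comparison) are removed.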
---
Full diff: https://github.com/llvm/llvm-project/pull/159845.diff
3 Files Affected:
- (modified) clang/lib/Analysis/LifetimeSafety.cpp (+4-8)
- (modified) llvm/include/llvm/ADT/ImmutableSet.h (+34-14)
- (modified) llvm/unittests/ADT/ImmutableSetTest.cpp (+33)
``````````diff
diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp
index d016c6f12e82e..0dd5716d93fb6 100644
--- a/clang/lib/Analysis/LifetimeSafety.cpp
+++ b/clang/lib/Analysis/LifetimeSafety.cpp
@@ -910,13 +910,10 @@ template <typename T>
static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A,
llvm::ImmutableSet<T> B,
typename llvm::ImmutableSet<T>::Factory &F) {
- if (A == B)
- return A;
if (A.getHeight() < B.getHeight())
std::swap(A, B);
for (const T &E : B)
- if (!A.contains(E))
- A = F.add(A, E);
+ A = F.add(A, E);
return A;
}
@@ -950,11 +947,10 @@ join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B,
for (const auto &Entry : B) {
const K &Key = Entry.first;
const V &ValB = Entry.second;
- const V *ValA = A.lookup(Key);
- if (!ValA)
- A = F.add(A, Key, ValB);
- else if (*ValA != ValB)
+ if (const V *ValA = A.lookup(Key))
A = F.add(A, Key, JoinValues(*ValA, ValB));
+ else
+ A = F.add(A, Key, ValB);
}
return A;
}
diff --git a/llvm/include/llvm/ADT/ImmutableSet.h b/llvm/include/llvm/ADT/ImmutableSet.h
index ac86f43b2048e..c2d84d86b5e27 100644
--- a/llvm/include/llvm/ADT/ImmutableSet.h
+++ b/llvm/include/llvm/ADT/ImmutableSet.h
@@ -531,7 +531,7 @@ class ImutAVLFactory {
/// add_internal - Creates a new tree that includes the specified
/// data and the data from the original tree. If the original tree
/// already contained the data item, the original tree is returned.
- TreeTy* add_internal(value_type_ref V, TreeTy* T) {
+ TreeTy *add_internal(value_type_ref V, TreeTy *T) {
if (isEmpty(T))
return createNode(T, V, T);
assert(!T->isMutable());
@@ -539,19 +539,34 @@ class ImutAVLFactory {
key_type_ref K = ImutInfo::KeyOfValue(V);
key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
- if (ImutInfo::isEqual(K,KCurrent))
+ if (ImutInfo::isEqual(K, KCurrent)) {
+ // If both key and value are same, return the original tree.
+ if (ImutInfo::isDataEqual(ImutInfo::DataOfValue(V),
+ ImutInfo::DataOfValue(getValue(T))))
+ return T;
+ // Otherwise create a new node with the new value.
return createNode(getLeft(T), V, getRight(T));
- else if (ImutInfo::isLess(K,KCurrent))
- return balanceTree(add_internal(V, getLeft(T)), getValue(T), getRight(T));
+ }
+
+ TreeTy *NewL = getLeft(T);
+ TreeTy *NewR = getRight(T);
+ if (ImutInfo::isLess(K, KCurrent))
+ NewL = add_internal(V, NewL);
else
- return balanceTree(getLeft(T), getValue(T), add_internal(V, getRight(T)));
+ NewR = add_internal(V, NewR);
+
+ // If no changes were made, return the original tree. Otherwise, balance the
+ // tree and return the new root.
+ return NewL == getLeft(T) && NewR == getRight(T)
+ ? T
+ : balanceTree(NewL, getValue(T), NewR);
}
/// remove_internal - Creates a new tree that includes all the data
/// from the original tree except the specified data. If the
/// specified data did not exist in the original tree, the original
/// tree is returned.
- TreeTy* remove_internal(key_type_ref K, TreeTy* T) {
+ TreeTy* remove_internal(key_type_ref K, TreeTy* T) {
if (isEmpty(T))
return T;
@@ -559,15 +574,20 @@ class ImutAVLFactory {
key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
- if (ImutInfo::isEqual(K,KCurrent)) {
+ if (ImutInfo::isEqual(K, KCurrent))
return combineTrees(getLeft(T), getRight(T));
- } else if (ImutInfo::isLess(K,KCurrent)) {
- return balanceTree(remove_internal(K, getLeft(T)),
- getValue(T), getRight(T));
- } else {
- return balanceTree(getLeft(T), getValue(T),
- remove_internal(K, getRight(T)));
- }
+
+ TreeTy *NewL = getLeft(T);
+ TreeTy *NewR = getRight(T);
+ if (ImutInfo::isLess(K, KCurrent))
+ NewL = remove_internal(K, NewL);
+ else
+ NewR = remove_internal(K, NewR);
+
+ // If no changes were made, return the original tree. Otherwise, balance the
+ // tree and return the new root.
+ return NewL == getLeft(T) && NewR == getRight(T) ? T
+ : balanceTree(NewL, getValue(T), NewR);
}
TreeTy* combineTrees(TreeTy* L, TreeTy* R) {
diff --git a/llvm/unittests/ADT/ImmutableSetTest.cpp b/llvm/unittests/ADT/ImmutableSetTest.cpp
index c0bde4c4d680b..c85a642d06eb2 100644
--- a/llvm/unittests/ADT/ImmutableSetTest.cpp
+++ b/llvm/unittests/ADT/ImmutableSetTest.cpp
@@ -164,4 +164,37 @@ TEST_F(ImmutableSetTest, IterLongSetTest) {
ASSERT_EQ(6, i);
}
+TEST_F(ImmutableSetTest, AddIfNotFoundTest) {
+ ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+ ImmutableSet<long> S = f.getEmptySet();
+ S = f.add(S, 1);
+ S = f.add(S, 2);
+ S = f.add(S, 3);
+
+ ImmutableSet<long> T1 = f.add(S, 1);
+ ImmutableSet<long> T2 = f.add(S, 2);
+ ImmutableSet<long> T3 = f.add(S, 3);
+ EXPECT_EQ(S.getRoot(), T1.getRoot());
+ EXPECT_EQ(S.getRoot(), T2.getRoot());
+ EXPECT_EQ(S.getRoot(), T3.getRoot());
+
+ ImmutableSet<long> U = f.add(S, 4);
+ EXPECT_NE(S.getRoot(), U.getRoot());
+}
+
+
+TEST_F(ImmutableSetTest, RemoveIfNotFoundTest) {
+ ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+ ImmutableSet<long> S = f.getEmptySet();
+ S = f.add(S, 1);
+ S = f.add(S, 2);
+ S = f.add(S, 3);
+
+ ImmutableSet<long> T = f.remove(S, 4);
+ EXPECT_EQ(S.getRoot(), T.getRoot());
+
+ ImmutableSet<long> U = f.remove(S, 3);
+ EXPECT_NE(S.getRoot(), U.getRoot());
+}
+
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/159845
More information about the llvm-commits
mailing list