[llvm] [SetOperations] Support set containers with remove_if (PR #96613)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 07:19:49 PDT 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/96613

>From 6ee8b4927c1cb8f0bd4585bb2d1695ddd94e2f56 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 25 Jun 2024 11:32:09 +0200
Subject: [PATCH 1/4] [SetOperations] Support set containers with remove_if

The current set_intersect implementation only works for std::set-style
sets whose erase-by-value method does not invalidate iterators to
other elements. As such, it cannot be used for set containers like
SetVector, which only provides an iterator-invalidating erase.

Support such set containers by calling their remove_if method instead,
when one exists. The detection code is adapted from how contains()
is detected inside llvm::is_contained().
---
 llvm/include/llvm/ADT/SetOperations.h    | 27 +++++++++++++++++++-----
 llvm/unittests/ADT/SetOperationsTest.cpp | 11 ++++++++++
 2 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 6c04c764e5207..9e125388509bd 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -15,8 +15,20 @@
 #ifndef LLVM_ADT_SETOPERATIONS_H
 #define LLVM_ADT_SETOPERATIONS_H
 
+#include "llvm/ADT/STLExtras.h"
+
 namespace llvm {
 
+namespace detail {
+template <typename Set, typename Fn>
+using check_has_member_remove_if_t =
+    decltype(std::declval<Set>().remove_if(std::declval<Fn>()));
+
+template <typename Set, typename Fn>
+static constexpr bool HasMemberRemoveIf =
+    is_detected<check_has_member_remove_if_t, Set, Fn>::value;
+} // namespace detail
+
 /// set_union(A, B) - Compute A := A u B, return whether A changed.
 ///
 template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
@@ -36,11 +48,16 @@ template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
 /// elements that are not contained in S2.
 ///
 template <class S1Ty, class S2Ty> void set_intersect(S1Ty &S1, const S2Ty &S2) {
-  for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
-    const auto &E = *I;
-    ++I;
-    if (!S2.count(E))
-      S1.erase(E); // Erase element if not in S2
+  if constexpr (detail::HasMemberRemoveIf<S1Ty,
+                                          bool (*)(decltype(*S2.begin()))>) {
+    S1.remove_if([S2](const auto &E) { return !S2.count(E); });
+  } else {
+    for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
+      const auto &E = *I;
+      ++I;
+      if (!S2.count(E))
+        S1.erase(E); // Erase element if not in S2
+    }
   }
 }
 
diff --git a/llvm/unittests/ADT/SetOperationsTest.cpp b/llvm/unittests/ADT/SetOperationsTest.cpp
index 982ea819fd809..7bd5189e48821 100644
--- a/llvm/unittests/ADT/SetOperationsTest.cpp
+++ b/llvm/unittests/ADT/SetOperationsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -65,6 +66,16 @@ TEST(SetOperationsTest, SetIntersect) {
   // is empty as they are non-overlapping.
   EXPECT_THAT(Set1, IsEmpty());
   EXPECT_EQ(ExpectedSet2, Set2);
+
+  // Check that set_intersect works on SetVector via remove_if.
+  SmallSetVector<int, 4> SV;
+  SV.insert(3);
+  SV.insert(6);
+  SV.insert(4);
+  SV.insert(5);
+  set_intersect(SV, Set2);
+  // SV should contain only 6 and 5 now.
+  EXPECT_EQ(SV.getArrayRef(), ArrayRef({6, 5}));
 }
 
 TEST(SetOperationsTest, SetIntersection) {

>From 0e9d0b36614745c65cb786202f2fa73ad607d648 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 25 Jun 2024 16:04:58 +0200
Subject: [PATCH 2/4] Capture S2 by reference in the remove_if predicate

---
 llvm/include/llvm/ADT/SetOperations.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 9e125388509bd..9c191a7b9223f 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -50,7 +50,7 @@ template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
 template <class S1Ty, class S2Ty> void set_intersect(S1Ty &S1, const S2Ty &S2) {
   if constexpr (detail::HasMemberRemoveIf<S1Ty,
                                           bool (*)(decltype(*S2.begin()))>) {
-    S1.remove_if([S2](const auto &E) { return !S2.count(E); });
+    S1.remove_if([&S2](const auto &E) { return !S2.count(E); });
   } else {
     for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
       const auto &E = *I;

>From b94e9edad6686256053d8bb77d80609dee3b3e79 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 25 Jun 2024 16:06:03 +0200
Subject: [PATCH 3/4] Store the predicate lambda and detect remove_if via
 decltype(Pred)

---
 llvm/include/llvm/ADT/SetOperations.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 9c191a7b9223f..1a911b239f4c6 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -48,9 +48,9 @@ template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
 /// elements that are not contained in S2.
 ///
 template <class S1Ty, class S2Ty> void set_intersect(S1Ty &S1, const S2Ty &S2) {
-  if constexpr (detail::HasMemberRemoveIf<S1Ty,
-                                          bool (*)(decltype(*S2.begin()))>) {
-    S1.remove_if([&S2](const auto &E) { return !S2.count(E); });
+  auto Pred = [&S2](const auto &E) { return !S2.count(E); };
+  if constexpr (detail::HasMemberRemoveIf<S1Ty, decltype(Pred)>) {
+    S1.remove_if(Pred);
   } else {
     for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
       const auto &E = *I;

>From a86218c4f068cdb26f8d96c680756d2e8d7f0f15 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 25 Jun 2024 16:19:28 +0200
Subject: [PATCH 4/4] Use testing::ElementsAre in the SetVector test

---
 llvm/unittests/ADT/SetOperationsTest.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/unittests/ADT/SetOperationsTest.cpp b/llvm/unittests/ADT/SetOperationsTest.cpp
index 7bd5189e48821..b3d931cbfd479 100644
--- a/llvm/unittests/ADT/SetOperationsTest.cpp
+++ b/llvm/unittests/ADT/SetOperationsTest.cpp
@@ -75,7 +75,7 @@ TEST(SetOperationsTest, SetIntersect) {
   SV.insert(5);
   set_intersect(SV, Set2);
   // SV should contain only 6 and 5 now.
-  EXPECT_EQ(SV.getArrayRef(), ArrayRef({6, 5}));
+  EXPECT_THAT(SV, testing::ElementsAre(6, 5));
 }
 
 TEST(SetOperationsTest, SetIntersection) {



More information about the llvm-commits mailing list