[llvm] [ADT] Make set_subtract more efficient when subtrahend is larger (NFC) (PR #99401)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 16:12:02 PDT 2024


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/99401

This patch is based on:

  commit fffe2728534a238ff0024e11a18280f85094dcde
  Author: Teresa Johnson <tejohnson at google.com>
  Date:   Wed Jul 17 13:53:10 2024 -0700

This iteration comes with a couple of improvements:

- We now accommodate S1Ty being SmallPtrSet, which has remove_if(pred)
  but not erase(iterator).  (Lack of this code path broke the mlir
  build.)

- The code path for erase(iterator) now saves the successor with
  std::next() before erasing the current element, so the loop never
  advances an invalidated iterator.
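The second point is the usual erase-while-iterating hazard. After `S1.erase(SI)`, `SI` no longer refers to a valid element, so advancing it is undefined behavior. The successor must therefore be captured first. Below is a minimal sketch of that pattern for a `std::set`-like container. `subtractByIteratingS1` is a hypothetical name, not LLVM API, and `count()` stands in for LLVM's `contains()` so the sketch builds as C++17.

```cpp
#include <cassert>
#include <iterator>
#include <set>

// Hypothetical helper (not LLVM API): remove from S1 every element that S2
// also holds, walking S1 rather than S2. count() stands in for contains().
template <class S1Ty, class S2Ty>
void subtractByIteratingS1(S1Ty &S1, const S2Ty &S2) {
  typename S1Ty::iterator Next;
  for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
       SI = Next) {
    // Save the successor first: erase(SI) invalidates SI itself, but for
    // std::set it leaves other iterators (including Next and SE) valid.
    Next = std::next(SI);
    if (S2.count(*SI))
      S1.erase(SI);
  }
}
```

For example, subtracting `{2, 4, 6, 8}` from `{1, 2, 3}` this way visits only the three elements of the smaller set and leaves `{1, 3}`.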


From 231204d6253e7ba619faf0f909990d3fadf32bf5 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Wed, 17 Jul 2024 15:45:25 -0700
Subject: [PATCH] [ADT] Make set_subtract more efficient when subtrahend is
 larger (NFC)

---
 llvm/include/llvm/ADT/SetOperations.h | 37 +++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 1a911b239f4c6..2b1a103565f7d 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -27,6 +27,15 @@ using check_has_member_remove_if_t =
 template <typename Set, typename Fn>
 static constexpr bool HasMemberRemoveIf =
     is_detected<check_has_member_remove_if_t, Set, Fn>::value;
+
+template <typename Set>
+using check_has_member_erase_iter_t =
+    decltype(std::declval<Set>().erase(std::declval<Set>().begin()));
+
+template <typename Set>
+static constexpr bool HasMemberEraseIter =
+    is_detected<check_has_member_erase_iter_t, Set>::value;
+
 } // namespace detail
 
 /// set_union(A, B) - Compute A := A u B, return whether A changed.
@@ -94,7 +103,35 @@ S1Ty set_difference(const S1Ty &S1, const S2Ty &S2) {
 
 /// set_subtract(A, B) - Compute A := A - B
 ///
+/// Selects the set to iterate based on the relative sizes of A and B for better
+/// efficiency.
+///
 template <class S1Ty, class S2Ty> void set_subtract(S1Ty &S1, const S2Ty &S2) {
+  // If S1 is smaller than S2, iterate on S1 provided that S2 supports efficient
+  // lookups via contains().  Note that a couple callers pass a vector for S2,
+  // which doesn't support contains(), and wouldn't be efficient if it did.
+  using ElemTy = decltype(*S1.begin());
+  if constexpr (detail::HasMemberContains<S2Ty, ElemTy>) {
+    auto Pred = [&S2](const auto &E) { return S2.contains(E); };
+    if constexpr (detail::HasMemberRemoveIf<S1Ty, decltype(Pred)>) {
+      if (S1.size() < S2.size()) {
+        S1.remove_if(Pred);
+        return;
+      }
+    } else if constexpr (detail::HasMemberEraseIter<S1Ty>) {
+      if (S1.size() < S2.size()) {
+        typename S1Ty::iterator Next;
+        for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
+             SI = Next) {
+          Next = std::next(SI);
+          if (S2.contains(*SI))
+            S1.erase(SI);
+        }
+        return;
+      }
+    }
+  }
+
   for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
        ++SI)
     S1.erase(*SI);
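Taken together, the patch chooses which set to walk in two stages. At compile time it checks which members exist, and at run time it checks which set is smaller. The C++17 sketch below mirrors that dispatch outside LLVM. It rests on several assumptions. `IsDetected`, `EraseIterT`, `CountT`, and `subtractSketch` are hypothetical names. A `std::void_t` detector stands in for LLVM's `is_detected`, and `count()` stands in for `contains()`. The `remove_if` path is omitted.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <type_traits>
#include <utility>
#include <vector>

// Minimal detection idiom (stand-in for LLVM's is_detected).
template <class AlwaysVoid, template <class...> class Op, class... Args>
struct Detector : std::false_type {};
template <template <class...> class Op, class... Args>
struct Detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};
template <template <class...> class Op, class... Args>
constexpr bool IsDetected = Detector<void, Op, Args...>::value;

// Same shape as check_has_member_erase_iter_t in the patch.
template <class Set>
using EraseIterT =
    decltype(std::declval<Set &>().erase(std::declval<Set &>().begin()));

// Hypothetical lookup probe: count() instead of contains() for C++17.
template <class Set, class E>
using CountT =
    decltype(std::declval<const Set &>().count(std::declval<const E &>()));

template <class S1Ty, class S2Ty> void subtractSketch(S1Ty &S1, const S2Ty &S2) {
  using ElemTy = typename S1Ty::value_type;
  if constexpr (IsDetected<CountT, S2Ty, ElemTy> &&
                IsDetected<EraseIterT, S1Ty>) {
    // S2 has a cheap membership test and S1 can erase by iterator, so walk
    // S1 when it is the smaller set.
    if (S1.size() < S2.size()) {
      for (auto SI = S1.begin(), SE = S1.end(); SI != SE;) {
        auto Next = std::next(SI); // saved before erase() invalidates SI
        if (S2.count(*SI))
          S1.erase(SI);
        SI = Next;
      }
      return;
    }
  }
  // Fallback, as in the patch: walk S2 and erase each element from S1.
  for (const auto &E : S2)
    S1.erase(E);
}
```

A `std::vector` passed as `S2` has no `count()`, so the sketch, like the patch with its `contains()` check, always takes the fallback that walks `S2`. A linear scan per lookup would defeat the purpose of walking the smaller set.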


