[llvm] [ADT] Make set_subtract more efficient when subtrahend is larger (NFC) (PR #99401)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 17 16:12:32 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-adt
Author: Kazu Hirata (kazutakahirata)
<details>
<summary>Changes</summary>
This patch is based on:
commit fffe2728534a238ff0024e11a18280f85094dcde
Author: Teresa Johnson <tejohnson@google.com>
Date: Wed Jul 17 13:53:10 2024 -0700
This iteration comes with a couple of improvements:
- We now accommodate S1Ty being SmallPtrSet, which has remove_if(pred)
but not erase(iterator). (The lack of this code path broke the MLIR
build.)
- The code path for erase(iterator) now computes the next iterator before
erasing the current element, to avoid problems with iterator invalidation.
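The second item above relies on a standard pattern: when erasing through an iterator while walking a container, capture the successor first, because `erase(SI)` invalidates `SI`. Below is a minimal C++17 sketch of that loop. It assumes standard containers (`std::set` for the minuend, `std::unordered_set` for the subtrahend) and uses `count()` as a stand-in for LLVM's `contains()`. `subtractByIterating` and `checkSubtractByIterating` are hypothetical names, not LLVM APIs.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <unordered_set>

// Hypothetical helper (not an LLVM API): erase from S1 every element that S2
// also holds, walking S1 and erasing through iterators.
template <class S1Ty, class S2Ty>
void subtractByIterating(S1Ty &S1, const S2Ty &S2) {
  typename S1Ty::iterator Next;
  for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
       SI = Next) {
    // Grab the successor first: after S1.erase(SI), SI is invalid and
    // advancing it would be undefined behavior.
    Next = std::next(SI);
    if (S2.count(*SI)) // count() stands in for contains() (C++20).
      S1.erase(SI);
  }
}

// Usage check: {1, 2, 3, 4, 5} minus {2, 4, 6, 8} leaves {1, 3, 5}.
inline void checkSubtractByIterating() {
  std::set<int> A = {1, 2, 3, 4, 5};
  std::unordered_set<int> B = {2, 4, 6, 8};
  subtractByIterating(A, B);
  assert((A == std::set<int>{1, 3, 5}));
}
```

Keeping `SE = S1.end()` across erasures is safe for `std::set`, whose end iterator is never invalidated by `erase`.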
---
Full diff: https://github.com/llvm/llvm-project/pull/99401.diff
1 File Affected:
- (modified) llvm/include/llvm/ADT/SetOperations.h (+37)
``````````diff
diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 1a911b239f4c6..2b1a103565f7d 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -27,6 +27,15 @@ using check_has_member_remove_if_t =
 template <typename Set, typename Fn>
 static constexpr bool HasMemberRemoveIf =
     is_detected<check_has_member_remove_if_t, Set, Fn>::value;
+
+template <typename Set>
+using check_has_member_erase_iter_t =
+    decltype(std::declval<Set>().erase(std::declval<Set>().begin()));
+
+template <typename Set>
+static constexpr bool HasMemberEraseIter =
+    is_detected<check_has_member_erase_iter_t, Set>::value;
+
 } // namespace detail
 
 /// set_union(A, B) - Compute A := A u B, return whether A changed.
@@ -94,7 +103,35 @@ S1Ty set_difference(const S1Ty &S1, const S2Ty &S2) {
 
 /// set_subtract(A, B) - Compute A := A - B
 ///
+/// Selects the set to iterate based on the relative sizes of A and B for better
+/// efficiency.
+///
 template <class S1Ty, class S2Ty> void set_subtract(S1Ty &S1, const S2Ty &S2) {
+  // If S1 is smaller than S2, iterate on S1 provided that S2 supports efficient
+  // lookups via contains(). Note that a couple callers pass a vector for S2,
+  // which doesn't support contains(), and wouldn't be efficient if it did.
+  using ElemTy = decltype(*S1.begin());
+  if constexpr (detail::HasMemberContains<S2Ty, ElemTy>) {
+    auto Pred = [&S2](const auto &E) { return S2.contains(E); };
+    if constexpr (detail::HasMemberRemoveIf<S1Ty, decltype(Pred)>) {
+      if (S1.size() < S2.size()) {
+        S1.remove_if(Pred);
+        return;
+      }
+    } else if constexpr (detail::HasMemberEraseIter<S1Ty>) {
+      if (S1.size() < S2.size()) {
+        typename S1Ty::iterator Next;
+        for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
+             SI = Next) {
+          Next = std::next(SI);
+          if (S2.contains(*SI))
+            S1.erase(SI);
+        }
+        return;
+      }
+    }
+  }
+
   for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
        ++SI)
     S1.erase(*SI);
``````````
</details>
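The diff above picks the iteration strategy at compile time. It walks S1 only when S2 offers a cheap membership test and S1 can remove elements by predicate or by iterator, and then only if S1 is the smaller set. Otherwise it falls back to walking S2. The following self-contained C++17 sketch mimics that structure with standard containers. It makes several assumptions: a hand-written `std::void_t` detector replaces `llvm::is_detected`, `count()` replaces `contains()`, and the `remove_if` branch is omitted. `IsDetected`, `HasCountT`, `HasEraseIterT`, `setSubtractSketch` and `checkSetSubtractSketch` are hypothetical names.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

// Simplified stand-in for llvm::is_detected built on std::void_t.
template <class AlwaysVoid, template <class...> class Op, class... Args>
struct Detector : std::false_type {};
template <template <class...> class Op, class... Args>
struct Detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};
template <template <class...> class Op, class... Args>
constexpr bool IsDetected = Detector<void, Op, Args...>::value;

// Well-formed only if Set has a member count(Elem) (our contains() stand-in).
template <class Set, class Elem>
using HasCountT =
    decltype(std::declval<const Set &>().count(std::declval<const Elem &>()));

// Well-formed only if Set has a member erase(iterator).
template <class Set>
using HasEraseIterT =
    decltype(std::declval<Set &>().erase(std::declval<Set &>().begin()));

// A := A - B, iterating over the smaller side when the types allow it.
template <class S1Ty, class S2Ty>
void setSubtractSketch(S1Ty &S1, const S2Ty &S2) {
  using ElemTy = typename S1Ty::value_type;
  if constexpr (IsDetected<HasCountT, S2Ty, ElemTy> &&
                IsDetected<HasEraseIterT, S1Ty>) {
    if (S1.size() < S2.size()) {
      // Walk the smaller S1 and look each element up in S2.
      for (auto SI = S1.begin(), SE = S1.end(); SI != SE;) {
        auto Next = std::next(SI); // advance before erase invalidates SI
        if (S2.count(*SI))
          S1.erase(SI);
        SI = Next;
      }
      return;
    }
  }
  // Fallback (e.g. S2 is a std::vector): walk S2 and erase by value.
  for (const auto &E : S2)
    S1.erase(E);
}

inline void checkSetSubtractSketch() {
  std::set<int> Small = {1, 2, 3};
  std::unordered_set<int> Big = {2, 3, 4, 5, 6};
  setSubtractSketch(Small, Big); // S1 is smaller: iterates S1
  assert((Small == std::set<int>{1}));

  std::set<int> A = {1, 2, 3, 4, 5, 6};
  std::vector<int> V = {2, 4, 4}; // no count(): always iterates V
  setSubtractSketch(A, V);
  assert((A == std::set<int>{1, 3, 5, 6}));
}
```

The size comparison only pays off when lookups into S2 are cheap. That is why a vector subtrahend, whose membership test would be linear, keeps the original walk-S2 loop, as the comment in the diff notes.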
https://github.com/llvm/llvm-project/pull/99401