[llvm] 0eaacc2 - [ADT] Make llvm::is_contained call member `contains` or `find` when available

Jakub Kuderski via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 15 09:08:24 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-15T12:07:56-04:00
New Revision: 0eaacc25bb98f50cc98bab5f6ef8d6d67e112317

URL: https://github.com/llvm/llvm-project/commit/0eaacc25bb98f50cc98bab5f6ef8d6d67e112317
DIFF: https://github.com/llvm/llvm-project/commit/0eaacc25bb98f50cc98bab5f6ef8d6d67e112317.diff

LOG: [ADT] Make llvm::is_contained call member `contains` or `find` when available

With this change, calling `llvm::is_contained` no longer performs worse than
calling the member `contains`, even though the two have almost identical
names. Previously that was the case for most set/map classes, which can check
whether an element is present in O(1) or O(log n) time instead of the linear
scan done by `std::find`. For C++17 maps/sets without `.contains`, `.find` is
used when available, falling back to a linear scan with `std::find`.

I also considered detecting a member `contains` and triggering a
`static_assert` instead, but decided against it because it's just as easy
to do the right thing and call `.contains`. That approach would also make some
code fail only when compiled in C++20 mode, where more container types come
with `.contains` member functions.

Such calls already existed, e.g. `CommandLine.h` calling `is_contained` on a
`SmallPtrSet`, as well as one in a recent BOLT patch.

Reviewed By: kazu, dblaikie, MaskRay

Differential Revision: https://reviews.llvm.org/D146061

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/include/llvm/Analysis/LoopInfoImpl.h
    llvm/unittests/ADT/STLExtrasTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 06a8b86a7feb0..545e888c18230 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1906,11 +1906,40 @@ OutputIt move(R &&Range, OutputIt Out) {
   return std::move(adl_begin(Range), adl_end(Range), Out);
 }
 
-/// Wrapper function around std::find to detect if an element exists
-/// in a container.
+namespace detail {
+// The detectors below query `Range &` (an lvalue) because `is_contained`
+// invokes the members on its named `Range` parameter, which is an lvalue no
+// matter how the argument was passed, and `const Element &` to match the type
+// of its `Element` parameter.
+template <typename Range, typename Element>
+using check_has_member_contains_t =
+    decltype(std::declval<Range &>().contains(std::declval<const Element &>()));
+
+/// True iff `Range` has a member `.contains` callable with a `const Element &`.
+/// Declared `inline` (C++17) so all translation units including this header
+/// share a single definition instead of each getting an internal copy.
+template <typename Range, typename Element>
+inline constexpr bool HasMemberContains =
+    is_detected<check_has_member_contains_t, Range, Element>::value;
+
+// A member `find` only qualifies if its result can be compared against
+// `end()` with `!=`; a `find` whose result is not comparable with `end()`
+// (for example, one returning a position index) is ignored.
+template <typename Range, typename Element>
+using check_has_member_find_t =
+    decltype(std::declval<Range &>().find(std::declval<const Element &>()) !=
+             std::declval<Range &>().end());
+
+/// True iff `Range` has a member `.find` callable with a `const Element &`
+/// whose result can be compared with `Range.end()` using `!=`.
+template <typename Range, typename Element>
+inline constexpr bool HasMemberFind =
+    is_detected<check_has_member_find_t, Range, Element>::value;
+
+} // namespace detail
+
+/// Returns true if \p Element is found in \p Range.
+///
+/// The check is delegated to the first available of:
+///   1. a member `Range.contains(Element)`,
+///   2. a member `Range.find(Element)` compared against `Range.end()`,
+///   3. a linear `std::find` over the range.
+/// Use this in generic code, or for range types that do not expose a
+/// `.contains(Element)` member. Note that a member `contains`/`find` is used
+/// whenever it exists, even if its meaning differs from membership in the
+/// range's iteration sequence; pass the underlying range explicitly in that
+/// case.
 template <typename R, typename E>
 bool is_contained(R &&Range, const E &Element) {
-  return std::find(adl_begin(Range), adl_end(Range), Element) != adl_end(Range);
+  if constexpr (detail::HasMemberContains<R, E>) {
+    // The container's own membership query.
+    return Range.contains(Element);
+  } else if constexpr (detail::HasMemberFind<R, E>) {
+    // No `.contains`, but a member `.find` is usually sublinear for
+    // set/map-like containers.
+    return Range.find(Element) != Range.end();
+  } else {
+    // Generic fallback: linear scan over the iteration range.
+    auto Found = std::find(adl_begin(Range), adl_end(Range), Element);
+    return Found != adl_end(Range);
+  }
 }
 
 /// Returns true iff \p Element exists in \p Set. This overload takes \p Set as

diff  --git a/llvm/include/llvm/Analysis/LoopInfoImpl.h b/llvm/include/llvm/Analysis/LoopInfoImpl.h
index c509ee67cbacc..48f6281e48f8b 100644
--- a/llvm/include/llvm/Analysis/LoopInfoImpl.h
+++ b/llvm/include/llvm/Analysis/LoopInfoImpl.h
@@ -371,7 +371,7 @@ void LoopBase<BlockT, LoopT>::verifyLoop() const {
 
   // Check the parent loop pointer.
   if (ParentLoop) {
-    assert(is_contained(*ParentLoop, this) &&
+    assert(is_contained(ParentLoop->getSubLoops(), this) &&
            "Loop is not a subloop of its parent!");
   }
 #endif

diff  --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index f39d0975109a5..8ec3df10bc6c3 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1029,6 +1029,56 @@ TEST(STLExtrasTest, IsContainedInitializerList) {
   static_assert(!is_contained({1, 2, 3, 4}, 5), "It's not there :(");
 }
 
+TEST(STLExtrasTest, IsContainedMemberContains) {
+  // `llvm::is_contained` must dispatch to a member `.contains()` when one
+  // exists, and must pick it even though a member `.find()` is also present.
+  struct ContainsAndFind {
+    int *begin() { return nullptr; }
+    int *end() { return nullptr; }
+    int *find(int) { return nullptr; }
+    bool contains(int) const {
+      ++Calls;
+      return Answer;
+    }
+
+    bool Answer = false;
+    mutable unsigned Calls = 0;
+  };
+  ContainsAndFind C;
+
+  EXPECT_EQ(C.Calls, 0u);
+  EXPECT_FALSE(is_contained(C, 1));
+  EXPECT_EQ(C.Calls, 1u);
+
+  C.Answer = true;
+  EXPECT_TRUE(is_contained(C, 1));
+  EXPECT_EQ(C.Calls, 2u);
+}
+
+TEST(STLExtrasTest, IsContainedMemberFind) {
+  // Without a `.contains()`, `llvm::is_contained` should use the member
+  // `.find(x)` rather than scanning the range itself.
+  struct FindOnly {
+    std::vector<int> Data;
+    mutable unsigned FindCalls = 0;
+
+    auto begin() { return Data.begin(); }
+    auto end() { return Data.end(); }
+    auto find(int X) {
+      ++FindCalls;
+      return std::find(begin(), end(), X);
+    }
+  };
+  FindOnly C;
+  C.Data = {1, 2, 3};
+
+  EXPECT_EQ(C.FindCalls, 0u);
+  EXPECT_TRUE(is_contained(C, 1));
+  EXPECT_TRUE(is_contained(C, 3));
+  EXPECT_EQ(C.FindCalls, 2u);
+
+  EXPECT_FALSE(is_contained(C, 4));
+  EXPECT_EQ(C.FindCalls, 3u);
+}
+
 TEST(STLExtrasTest, addEnumValues) {
   enum A { Zero = 0, One = 1 };
   enum B { IntMax = INT_MAX, ULongLongMax = ULLONG_MAX };


        


More information about the llvm-commits mailing list