[llvm] r330875 - [ADT] Make filter_iterator support bidirectional iteration

Vedant Kumar via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 25 14:50:09 PDT 2018


Author: vedantk
Date: Wed Apr 25 14:50:09 2018
New Revision: 330875

URL: http://llvm.org/viewvc/llvm-project?rev=330875&view=rev
Log:
[ADT] Make filter_iterator support bidirectional iteration

This makes it possible to reverse a filtered range. For example, here's
a way to visit memory accesses in a BasicBlock in reverse order:

    auto MemInsts = reverse(make_filter_range(BB, [](Instruction &I) {
      return isa<StoreInst>(&I) || isa<LoadInst>(&I);
    }));

    for (auto &MI : MemInsts)
      ...

To implement this, I factored the forward-iteration logic out into
filter_iterator_base and added a specialization of filter_iterator_impl
that supports bidirectional iteration. Thanks to Tim Shen, Zachary Turner,
and others for suggesting this design and providing feedback! This version
of the patch supersedes the original (https://reviews.llvm.org/D45792).

This was motivated by a problem we encountered in D45657: we wanted to
visit the non-debug-info instructions in a BasicBlock in reverse order.

Testing: check-llvm, check-clang

Differential Revision: https://reviews.llvm.org/D45853

Modified:
    llvm/trunk/include/llvm/ADT/STLExtras.h
    llvm/trunk/unittests/ADT/IteratorTest.cpp
    llvm/trunk/unittests/IR/BasicBlockTest.cpp

Modified: llvm/trunk/include/llvm/ADT/STLExtras.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ADT/STLExtras.h?rev=330875&r1=330874&r2=330875&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ADT/STLExtras.h (original)
+++ llvm/trunk/include/llvm/ADT/STLExtras.h Wed Apr 25 14:50:09 2018
@@ -271,60 +271,121 @@ auto reverse(
 ///   auto R = make_filter_range(A, [](int N) { return N % 2 == 1; });
 ///   // R contains { 1, 3 }.
 /// \endcode
-template <typename WrappedIteratorT, typename PredicateT>
-class filter_iterator
+///
+/// Note: filter_iterator_base implements support for forward iteration.
+/// filter_iterator_impl exists to provide support for bidirectional iteration,
+/// conditional on whether the wrapped iterator supports it.
+template <typename WrappedIteratorT, typename PredicateT, typename IterTag>
+class filter_iterator_base
     : public iterator_adaptor_base<
-          filter_iterator<WrappedIteratorT, PredicateT>, WrappedIteratorT,
+          filter_iterator_base<WrappedIteratorT, PredicateT, IterTag>,
+          WrappedIteratorT,
           typename std::common_type<
-              std::forward_iterator_tag,
-              typename std::iterator_traits<
-                  WrappedIteratorT>::iterator_category>::type> {
+              IterTag, typename std::iterator_traits<
+                           WrappedIteratorT>::iterator_category>::type> {
   using BaseT = iterator_adaptor_base<
-      filter_iterator<WrappedIteratorT, PredicateT>, WrappedIteratorT,
+      filter_iterator_base<WrappedIteratorT, PredicateT, IterTag>,
+      WrappedIteratorT,
       typename std::common_type<
-          std::forward_iterator_tag,
-          typename std::iterator_traits<WrappedIteratorT>::iterator_category>::
-          type>;
-
-  struct PayloadType {
-    WrappedIteratorT End;
-    PredicateT Pred;
-  };
+          IterTag, typename std::iterator_traits<
+                       WrappedIteratorT>::iterator_category>::type>;
 
-  Optional<PayloadType> Payload;
+protected:
+  WrappedIteratorT End;
+  PredicateT Pred;
 
   void findNextValid() {
-    assert(Payload && "Payload should be engaged when findNextValid is called");
-    while (this->I != Payload->End && !Payload->Pred(*this->I))
+    while (this->I != End && !Pred(*this->I))
       BaseT::operator++();
   }
 
-  // Construct the begin iterator. The begin iterator requires to know where end
-  // is, so that it can properly stop when it hits end.
-  filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, PredicateT Pred)
-      : BaseT(std::move(Begin)),
-        Payload(PayloadType{std::move(End), std::move(Pred)}) {
+  // Construct the iterator. The begin iterator needs to know where the end
+  // is, so that it can properly stop when it gets there. The end iterator only
+  // needs the predicate to support bidirectional iteration.
+  filter_iterator_base(WrappedIteratorT Begin, WrappedIteratorT End,
+                       PredicateT Pred)
+      : BaseT(Begin), End(End), Pred(Pred) {
     findNextValid();
   }
 
-  // Construct the end iterator. It's not incrementable, so Payload doesn't
-  // have to be engaged.
-  filter_iterator(WrappedIteratorT End) : BaseT(End) {}
-
 public:
   using BaseT::operator++;
 
-  filter_iterator &operator++() {
+  filter_iterator_base &operator++() {
     BaseT::operator++();
     findNextValid();
     return *this;
   }
+};
 
-  template <typename RT, typename PT>
-  friend iterator_range<filter_iterator<detail::IterOfRange<RT>, PT>>
-  make_filter_range(RT &&, PT);
+/// Specialization of filter_iterator_base for forward iteration only.
+template <typename WrappedIteratorT, typename PredicateT,
+          typename IterTag = std::forward_iterator_tag>
+class filter_iterator_impl
+    : public filter_iterator_base<WrappedIteratorT, PredicateT, IterTag> {
+  using BaseT = filter_iterator_base<WrappedIteratorT, PredicateT, IterTag>;
+
+public:
+  filter_iterator_impl(WrappedIteratorT Begin, WrappedIteratorT End,
+                       PredicateT Pred)
+      : BaseT(Begin, End, Pred) {}
 };
 
+/// Specialization of filter_iterator_base for bidirectional iteration.
+template <typename WrappedIteratorT, typename PredicateT>
+class filter_iterator_impl<WrappedIteratorT, PredicateT,
+                           std::bidirectional_iterator_tag>
+    : public filter_iterator_base<WrappedIteratorT, PredicateT,
+                                  std::bidirectional_iterator_tag> {
+  using BaseT = filter_iterator_base<WrappedIteratorT, PredicateT,
+                                     std::bidirectional_iterator_tag>;
+  void findPrevValid() {
+    while (!this->Pred(*this->I))
+      BaseT::operator--();
+  }
+
+public:
+  using BaseT::operator--;
+
+  filter_iterator_impl(WrappedIteratorT Begin, WrappedIteratorT End,
+                       PredicateT Pred)
+      : BaseT(Begin, End, Pred) {}
+
+  filter_iterator_impl &operator--() {
+    BaseT::operator--();
+    findPrevValid();
+    return *this;
+  }
+};
+
+namespace detail {
+
+template <bool is_bidirectional> struct fwd_or_bidi_tag_impl {
+  using type = std::forward_iterator_tag;
+};
+
+template <> struct fwd_or_bidi_tag_impl<true> {
+  using type = std::bidirectional_iterator_tag;
+};
+
+/// Helper which sets its type member to forward_iterator_tag if the category
+/// of \p IterT does not derive from bidirectional_iterator_tag, and to
+/// bidirectional_iterator_tag otherwise.
+template <typename IterT> struct fwd_or_bidi_tag {
+  using type = typename fwd_or_bidi_tag_impl<std::is_base_of<
+      std::bidirectional_iterator_tag,
+      typename std::iterator_traits<IterT>::iterator_category>::value>::type;
+};
+
+} // namespace detail
+
+/// Defines filter_iterator to a suitable specialization of
+/// filter_iterator_impl, based on the underlying iterator's category.
+template <typename WrappedIteratorT, typename PredicateT>
+using filter_iterator = filter_iterator_impl<
+    WrappedIteratorT, PredicateT,
+    typename detail::fwd_or_bidi_tag<WrappedIteratorT>::type>;
+
 /// Convenience function that takes a range of elements and a predicate,
 /// and return a new filter_iterator range.
 ///
@@ -337,10 +398,11 @@ iterator_range<filter_iterator<detail::I
 make_filter_range(RangeT &&Range, PredicateT Pred) {
   using FilterIteratorT =
       filter_iterator<detail::IterOfRange<RangeT>, PredicateT>;
-  return make_range(FilterIteratorT(std::begin(std::forward<RangeT>(Range)),
-                                    std::end(std::forward<RangeT>(Range)),
-                                    std::move(Pred)),
-                    FilterIteratorT(std::end(std::forward<RangeT>(Range))));
+  return make_range(
+      FilterIteratorT(std::begin(std::forward<RangeT>(Range)),
+                      std::end(std::forward<RangeT>(Range)), Pred),
+      FilterIteratorT(std::end(std::forward<RangeT>(Range)),
+                      std::end(std::forward<RangeT>(Range)), Pred));
 }
 
 // forward declarations required by zip_shortest/zip_first

Modified: llvm/trunk/unittests/ADT/IteratorTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ADT/IteratorTest.cpp?rev=330875&r1=330874&r2=330875&view=diff
==============================================================================
--- llvm/trunk/unittests/ADT/IteratorTest.cpp (original)
+++ llvm/trunk/unittests/ADT/IteratorTest.cpp Wed Apr 25 14:50:09 2018
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/iterator.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
@@ -196,6 +197,33 @@ TEST(FilterIteratorTest, InputIterator)
   EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
 }
 
+TEST(FilterIteratorTest, ReverseFilterRange) {
+  auto IsOdd = [](int N) { return N % 2 == 1; };
+  int A[] = {0, 1, 2, 3, 4, 5, 6};
+
+  // Check basic reversal.
+  auto Range = reverse(make_filter_range(A, IsOdd));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{5, 3, 1}), Actual);
+
+  // Check that the reverse of the reverse is the original.
+  auto Range2 = reverse(reverse(make_filter_range(A, IsOdd)));
+  SmallVector<int, 3> Actual2(Range2.begin(), Range2.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual2);
+
+  // Check empty ranges.
+  auto Range3 = reverse(make_filter_range(ArrayRef<int>(), IsOdd));
+  SmallVector<int, 0> Actual3(Range3.begin(), Range3.end());
+  EXPECT_EQ((SmallVector<int, 0>{}), Actual3);
+
+  // Check that we don't skip the first element, provided it isn't filtered
+  // away.
+  auto IsEven = [](int N) { return N % 2 == 0; };
+  auto Range4 = reverse(make_filter_range(A, IsEven));
+  SmallVector<int, 4> Actual4(Range4.begin(), Range4.end());
+  EXPECT_EQ((SmallVector<int, 4>{6, 4, 2, 0}), Actual4);
+}
+
 TEST(PointerIterator, Basic) {
   int A[] = {1, 2, 3, 4};
   pointer_iterator<int *> Begin(std::begin(A)), End(std::end(A));

Modified: llvm/trunk/unittests/IR/BasicBlockTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/IR/BasicBlockTest.cpp?rev=330875&r1=330874&r2=330875&view=diff
==============================================================================
--- llvm/trunk/unittests/IR/BasicBlockTest.cpp (original)
+++ llvm/trunk/unittests/IR/BasicBlockTest.cpp Wed Apr 25 14:50:09 2018
@@ -69,6 +69,15 @@ TEST(BasicBlockTest, PhiRange) {
   CI = BB->phis().begin();
   EXPECT_NE(CI, BB->phis().end());
 
+  // Test that filtering iterators work with basic blocks.
+  auto isPhi = [](Instruction &I) { return isa<PHINode>(&I); };
+  auto Phis = make_filter_range(*BB, isPhi);
+  auto ReversedPhis = reverse(make_filter_range(*BB, isPhi));
+  EXPECT_EQ(std::distance(Phis.begin(), Phis.end()), 3);
+  EXPECT_EQ(&*Phis.begin(), P1);
+  EXPECT_EQ(std::distance(ReversedPhis.begin(), ReversedPhis.end()), 3);
+  EXPECT_EQ(&*ReversedPhis.begin(), P3);
+
   // And iterate a const range.
   for (const auto &PN : const_cast<const BasicBlock *>(BB.get())->phis()) {
     EXPECT_EQ(BB.get(), PN.getIncomingBlock(0));




More information about the llvm-commits mailing list