[llvm] e772216 - [llvm] Make Sequence reverse-iterable
Guillaume Chatelet via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 8 06:19:15 PDT 2021
Author: Guillaume Chatelet
Date: 2021-06-08T13:18:57Z
New Revision: e772216e708937988c039420d2c559568f91ae27
URL: https://github.com/llvm/llvm-project/commit/e772216e708937988c039420d2c559568f91ae27
DIFF: https://github.com/llvm/llvm-project/commit/e772216e708937988c039420d2c559568f91ae27.diff
LOG: [llvm] Make Sequence reverse-iterable
This patch simplifies the implementation of Sequence and makes it compatible with llvm::reverse.
It exposes the reverse iterators through rbegin/rend which prevents a dangling reference in std::reverse_iterator::operator++().
Differential Revision: https://reviews.llvm.org/D102679
Added:
Modified:
llvm/include/llvm/ADT/Sequence.h
llvm/unittests/ADT/SequenceTest.cpp
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h
index 8a695d75f77ae..e75527fbaefd4 100644
--- a/llvm/include/llvm/ADT/Sequence.h
+++ b/llvm/include/llvm/ADT/Sequence.h
@@ -15,71 +15,167 @@
#ifndef LLVM_ADT_SEQUENCE_H
#define LLVM_ADT_SEQUENCE_H
-#include "llvm/ADT/iterator.h"
-#include "llvm/ADT/iterator_range.h"
-#include <algorithm>
#include <iterator>
-#include <utility>
namespace llvm {
namespace detail {
-template <typename ValueT>
-class value_sequence_iterator
- : public iterator_facade_base<value_sequence_iterator<ValueT>,
- std::random_access_iterator_tag,
- const ValueT> {
- using BaseT = typename value_sequence_iterator::iterator_facade_base;
-
- ValueT Value;
+template <typename T, bool IsReversed> struct iota_range_iterator {
+ using iterator_category = std::random_access_iterator_tag;
+ using value_type = T;
+ using difference_type = ptrdiff_t;
+ using pointer = T *;
+ using reference = T &;
+
+private:
+ struct Forward {
+ static void increment(T &V) { ++V; }
+ static void decrement(T &V) { --V; }
+ static void offset(T &V, difference_type Offset) { V += Offset; }
+ static T add(const T &V, difference_type Offset) { return V + Offset; }
+ static difference_type difference(const T &A, const T &B) { return A - B; }
+ };
+
+ struct Reverse {
+ static void increment(T &V) { --V; }
+ static void decrement(T &V) { ++V; }
+ static void offset(T &V, difference_type Offset) { V -= Offset; }
+ static T add(const T &V, difference_type Offset) { return V - Offset; }
+ static difference_type difference(const T &A, const T &B) { return B - A; }
+ };
+
+ using Op = std::conditional_t<!IsReversed, Forward, Reverse>;
public:
- using difference_type = typename BaseT::difference_type;
- using reference = typename BaseT::reference;
-
- value_sequence_iterator() = default;
- value_sequence_iterator(const value_sequence_iterator &) = default;
- value_sequence_iterator(value_sequence_iterator &&Arg)
- : Value(std::move(Arg.Value)) {}
- value_sequence_iterator &operator=(const value_sequence_iterator &Arg) {
- Value = Arg.Value;
- return *this;
+ // default-constructible
+ iota_range_iterator() = default;
+ // copy-constructible
+ iota_range_iterator(const iota_range_iterator &) = default;
+ // value constructor
+ explicit iota_range_iterator(T Value) : Value(Value) {}
+ // copy-assignable
+ iota_range_iterator &operator=(const iota_range_iterator &) = default;
+ // destructible
+ ~iota_range_iterator() = default;
+
+ // Can be compared for equivalence using the equality/inequality operators,
+ bool operator!=(const iota_range_iterator &RHS) const {
+ return Value != RHS.Value;
+ }
+ bool operator==(const iota_range_iterator &RHS) const {
+ return Value == RHS.Value;
+ }
+
+ // Comparison
+ bool operator<(const iota_range_iterator &Other) const {
+ return Op::difference(Value, Other.Value) < 0;
+ }
+ bool operator<=(const iota_range_iterator &Other) const {
+ return Op::difference(Value, Other.Value) <= 0;
+ }
+ bool operator>(const iota_range_iterator &Other) const {
+ return Op::difference(Value, Other.Value) > 0;
+ }
+ bool operator>=(const iota_range_iterator &Other) const {
+ return Op::difference(Value, Other.Value) >= 0;
}
- template <typename U, typename Enabler = decltype(ValueT(std::declval<U>()))>
- value_sequence_iterator(U &&Value) : Value(std::forward<U>(Value)) {}
+ // Dereference
+ T operator*() const { return Value; }
+ T operator[](difference_type Offset) const { return Op::add(Value, Offset); }
+
+ // Arithmetic
+ iota_range_iterator operator+(difference_type Offset) const {
+ return {Op::add(Value, Offset)};
+ }
+ iota_range_iterator operator-(difference_type Offset) const {
+ return {Op::add(Value, -Offset)};
+ }
+
+ // Iterator difference
+ difference_type operator-(const iota_range_iterator &Other) const {
+ return Op::difference(Value, Other.Value);
+ }
- value_sequence_iterator &operator+=(difference_type N) {
- Value += N;
+ // Pre/Post Increment
+ iota_range_iterator &operator++() {
+ Op::increment(Value);
return *this;
}
- value_sequence_iterator &operator-=(difference_type N) {
- Value -= N;
+ iota_range_iterator operator++(int) {
+ iota_range_iterator Tmp = *this;
+ Op::increment(Value);
+ return Tmp;
+ }
+
+ // Pre/Post Decrement
+ iota_range_iterator &operator--() {
+ Op::decrement(Value);
return *this;
}
- using BaseT::operator-;
- difference_type operator-(const value_sequence_iterator &RHS) const {
- return Value - RHS.Value;
+ iota_range_iterator operator--(int) {
+ iota_range_iterator Tmp = *this;
+ Op::decrement(Value);
+ return Tmp;
}
- bool operator==(const value_sequence_iterator &RHS) const {
- return Value == RHS.Value;
+ // Compound assignment operators
+ iota_range_iterator &operator+=(difference_type Offset) {
+ Op::offset(Value, Offset);
+ return *this;
}
- bool operator<(const value_sequence_iterator &RHS) const {
- return Value < RHS.Value;
+ iota_range_iterator &operator-=(difference_type Offset) {
+ Op::offset(Value, -Offset);
+ return *this;
}
- reference operator*() const { return Value; }
+private:
+ T Value;
};
-} // end namespace detail
+} // namespace detail
+
+template <typename ValueT> struct iota_range {
+ static_assert(std::is_integral<ValueT>::value,
+ "ValueT must be an integral type");
+
+ using value_type = ValueT;
+ using reference = ValueT &;
+ using const_reference = const ValueT &;
+ using iterator = detail::iota_range_iterator<value_type, false>;
+ using const_iterator = iterator;
+ using reverse_iterator = detail::iota_range_iterator<value_type, true>;
+ using const_reverse_iterator = reverse_iterator;
+ using difference_type = typename iterator::difference_type;
+ using size_type = std::size_t;
+
+ value_type Begin;
+ value_type End;
+
+ template <
+ typename BeginT, typename EndT,
+ std::enable_if_t<std::is_convertible<BeginT, ValueT>::value, bool> = true,
+ std::enable_if_t<std::is_convertible<EndT, ValueT>::value, bool> = true>
+ iota_range(BeginT &&Begin, EndT &&End)
+ : Begin(std::forward<BeginT>(Begin)), End(std::forward<EndT>(End)) {}
+
+ size_t size() const { return End - Begin; }
+ bool empty() const { return Begin == End; }
+
+ auto begin() const { return const_iterator(Begin); }
+ auto end() const { return const_iterator(End); }
+
+ auto rbegin() const { return const_reverse_iterator(End - 1); }
+ auto rend() const { return const_reverse_iterator(Begin - 1); }
+
+private:
+ static_assert(std::is_same<ValueT, std::remove_cv_t<ValueT>>::value,
+ "ValueT must not be const nor volatile");
+};
-template <typename ValueT>
-iterator_range<detail::value_sequence_iterator<ValueT>> seq(ValueT Begin,
- ValueT End) {
- return make_range(detail::value_sequence_iterator<ValueT>(Begin),
- detail::value_sequence_iterator<ValueT>(End));
+template <typename ValueT> auto seq(ValueT Begin, ValueT End) {
+ return iota_range<ValueT>(std::move(Begin), std::move(End));
}
} // end namespace llvm
diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp
index 4356bb18a0cd9..0873b37f9b642 100644
--- a/llvm/unittests/ADT/SequenceTest.cpp
+++ b/llvm/unittests/ADT/SequenceTest.cpp
@@ -15,26 +15,37 @@ using namespace llvm;
namespace {
-TEST(SequenceTest, Basic) {
- int x = 0;
- for (int i : seq(0, 10)) {
- EXPECT_EQ(x, i);
- x++;
+TEST(SequenceTest, Forward) {
+ int X = 0;
+ for (int I : seq(0, 10)) {
+ EXPECT_EQ(X, I);
+ ++X;
}
- EXPECT_EQ(10, x);
+ EXPECT_EQ(10, X);
+}
- auto my_seq = seq(0, 4);
- EXPECT_EQ(4, my_seq.end() - my_seq.begin());
- for (int i : {0, 1, 2, 3})
- EXPECT_EQ(i, (int)my_seq.begin()[i]);
+TEST(SequenceTest, Backward) {
+ int X = 9;
+ for (int I : reverse(seq(0, 10))) {
+ EXPECT_EQ(X, I);
+ --X;
+ }
+ EXPECT_EQ(-1, X);
+}
- EXPECT_TRUE(my_seq.begin() < my_seq.end());
+TEST(SequenceTest, Distance) {
+ const auto Forward = seq(0, 10);
+ EXPECT_EQ(std::distance(Forward.begin(), Forward.end()), 10);
+ EXPECT_EQ(std::distance(Forward.rbegin(), Forward.rend()), 10);
+}
- auto adjusted_begin = my_seq.begin() + 2;
- auto adjusted_end = my_seq.end() - 2;
- EXPECT_TRUE(adjusted_begin == adjusted_end);
- EXPECT_EQ(2, *adjusted_begin);
- EXPECT_EQ(2, *adjusted_end);
+TEST(SequenceTest, Dereferene) {
+ const auto Forward = seq(0, 10).begin();
+ EXPECT_EQ(Forward[0], 0);
+ EXPECT_EQ(Forward[2], 2);
+ const auto Backward = seq(0, 10).rbegin();
+ EXPECT_EQ(Backward[0], 9);
+ EXPECT_EQ(Backward[2], 7);
}
} // anonymous namespace
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 3d1d77cd89e9f..98a49c1bb6d8f 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -707,7 +707,7 @@ def Builtin_SparseElementsAttr
let extraClassDeclaration = [{
template <typename T>
using iterator =
- llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
+ llvm::mapped_iterator<decltype(llvm::seq<ptrdiff_t>(0, 0))::const_iterator,
std::function<T(ptrdiff_t)>>;
/// Return the values of this attribute in the form of the given type 'T'.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index abeaa20646287..197a3a059d622 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -380,8 +380,8 @@ void PatternLowering::generateSwitch(SwitchNode *switchNode,
if (kind == Predicates::OperandCountAtLeastQuestion ||
kind == Predicates::ResultCountAtLeastQuestion) {
// Order the children such that the cases are in reverse numerical order.
- SmallVector<unsigned> sortedChildren(
- llvm::seq<unsigned>(0, switchNode->getChildren().size()));
+ auto sequence = llvm::seq<unsigned>(0, switchNode->getChildren().size());
+ SmallVector<unsigned> sortedChildren(sequence.begin(), sequence.end());
llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b52059d535cf1..6e8264143e986 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -302,8 +302,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
// Compute the static loop sizes of the index op.
auto targetShape = linalgOp.computeStaticLoopSizes();
// Compute a one-dimensional index vector for the index op dimension.
- SmallVector<int64_t> constantSeq(
- llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
+ auto seq1 = llvm::seq<int64_t>(0, targetShape[indexOp.dim()]);
+ SmallVector<int64_t> constantSeq(seq1.begin(), seq1.end());
ConstantOp constantOp =
b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
// Return the one-dimensional index vector if it lives in the trailing
@@ -317,8 +317,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
std::swap(targetShape[indexOp.dim()], targetShape.back());
auto broadCastOp = b.create<vector::BroadcastOp>(
loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
- SmallVector<int64_t> transposition(
- llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
+ auto seq2 = llvm::seq<int64_t>(0, linalgOp.getNumLoops());
+ SmallVector<int64_t> transposition(seq2.begin(), seq2.end());
std::swap(transposition.back(), transposition[indexOp.dim()]);
auto transposeOp =
b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
More information about the llvm-commits
mailing list