[Mlir-commits] [mlir] 6de6131 - [mlir] Optimize usage of llvm::mapped_iterator
River Riddle
llvmlistbot at llvm.org
Wed Nov 10 19:44:54 PST 2021
Author: River Riddle
Date: 2021-11-11T03:26:29Z
New Revision: 6de6131f029d597188b05e666fb77fb69e5b936c
URL: https://github.com/llvm/llvm-project/commit/6de6131f029d597188b05e666fb77fb69e5b936c
DIFF: https://github.com/llvm/llvm-project/commit/6de6131f029d597188b05e666fb77fb69e5b936c.diff
LOG: [mlir] Optimize usage of llvm::mapped_iterator
mapped_iterator is a useful abstraction for applying a
map function over an existing iterator, but our current
usage ends up allocating storage/making indirect calls
even when the map function is a known function, which
is horribly inefficient. This commit refactors the usage
of mapped_iterator to avoid this, and allows for directly
referencing the map function when dereferencing.
Fixes PR52319
Differential Revision: https://reviews.llvm.org/D113511
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
llvm/unittests/ADT/MappedIteratorTest.cpp
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/DialectInterface.h
mlir/include/mlir/IR/TypeRange.h
mlir/include/mlir/IR/TypeUtilities.h
mlir/include/mlir/IR/UseDefLists.h
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/TypeUtilities.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index daa6d257dd000..6c2fb1029a918 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -307,6 +307,32 @@ auto map_range(ContainerTy &&C, FuncTy F) {
return make_range(map_iterator(C.begin(), F), map_iterator(C.end(), F));
}
+/// A base type of mapped iterator, that is useful for building derived
+/// iterators that do not need/want to store the map function (as in
+/// mapped_iterator). These iterators must simply provide a `mapElement` method
+/// that defines how to map a value of the iterator to the provided reference
+/// type.
+template <typename DerivedT, typename ItTy, typename ReferenceTy>
+class mapped_iterator_base
+ : public iterator_adaptor_base<
+ DerivedT, ItTy,
+ typename std::iterator_traits<ItTy>::iterator_category,
+ std::remove_reference_t<ReferenceTy>,
+ typename std::iterator_traits<ItTy>::difference_type,
+ std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
+public:
+ using BaseT = mapped_iterator_base<DerivedT, ItTy, ReferenceTy>;
+
+ mapped_iterator_base(ItTy U)
+ : mapped_iterator_base::iterator_adaptor_base(std::move(U)) {}
+
+ ItTy getCurrent() { return this->I; }
+
+ ReferenceTy operator*() const {
+ return static_cast<const DerivedT &>(*this).mapElement(*this->I);
+ }
+};
+
/// Helper to determine if type T has a member called rbegin().
template <typename Ty> class has_rbegin_impl {
using yes = char[1];
diff --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp
index 61e45e3a66eb1..f94709805c2cd 100644
--- a/llvm/unittests/ADT/MappedIteratorTest.cpp
+++ b/llvm/unittests/ADT/MappedIteratorTest.cpp
@@ -47,4 +47,67 @@ TEST(MappedIteratorTest, FunctionPreservesReferences) {
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
}
+TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnDereference) {
+ struct CustomMapIterator
+ : public llvm::mapped_iterator_base<CustomMapIterator,
+ std::vector<int>::iterator, int> {
+ using BaseT::BaseT;
+
+ /// Map the element to the iterator result type.
+ int mapElement(int X) const { return X + 1; }
+ };
+
+ std::vector<int> V({0});
+
+ CustomMapIterator I(V.begin());
+
+ EXPECT_EQ(*I, 1) << "should have applied function in dereference";
+}
+
+TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnArrow) {
+ struct S {
+ int Z = 0;
+ };
+ struct CustomMapIterator
+ : public llvm::mapped_iterator_base<CustomMapIterator,
+ std::vector<int>::iterator, S &> {
+ CustomMapIterator(std::vector<int>::iterator it, S *P) : BaseT(it), P(P) {}
+
+ /// Map the element to the iterator result type.
+ S &mapElement(int X) const { return *(P + X); }
+
+ S *P;
+ };
+
+ std::vector<int> V({0});
+ S Y;
+
+ CustomMapIterator I(V.begin(), &Y);
+
+ I->Z = 42;
+
+ EXPECT_EQ(Y.Z, 42) << "should have applied function during arrow";
+}
+
+TEST(MappedIteratorTest, CustomIteratorFunctionPreservesReferences) {
+ struct CustomMapIterator
+ : public llvm::mapped_iterator_base<CustomMapIterator,
+ std::vector<int>::iterator, int &> {
+ CustomMapIterator(std::vector<int>::iterator it, std::map<int, int> &M)
+ : BaseT(it), M(M) {}
+
+ /// Map the element to the iterator result type.
+ int &mapElement(int X) const { return M[X]; }
+
+ std::map<int, int> &M;
+ };
+ std::vector<int> V({1});
+ std::map<int, int> M({{1, 1}});
+
+ auto I = CustomMapIterator(V.begin(), M);
+ *I = 42;
+
+ EXPECT_EQ(M[1], 42) << "assignment should have modified M";
+}
+
} // anonymous namespace
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 6dabb6f3bf06c..db58d58d5453b 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -323,24 +323,46 @@ class DenseElementsAttr : public Attribute {
/// Iterator for walking over APFloat values.
class FloatElementIterator final
- : public llvm::mapped_iterator<IntElementIterator,
- std::function<APFloat(const APInt &)>> {
+ : public llvm::mapped_iterator_base<FloatElementIterator,
+ IntElementIterator, APFloat> {
+ public:
+ /// Map the element to the iterator result type.
+ APFloat mapElement(const APInt &value) const {
+ return APFloat(*smt, value);
+ }
+
+ private:
friend DenseElementsAttr;
/// Initializes the float element iterator to the specified iterator.
- FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
+ FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
+ : BaseT(it), smt(&smt) {}
+
+ /// The float semantics to use when constructing the APFloat.
+ const llvm::fltSemantics *smt;
};
/// Iterator for walking over complex APFloat values.
class ComplexFloatElementIterator final
- : public llvm::mapped_iterator<
- ComplexIntElementIterator,
- std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
+ : public llvm::mapped_iterator_base<ComplexFloatElementIterator,
+ ComplexIntElementIterator,
+ std::complex<APFloat>> {
+ public:
+ /// Map the element to the iterator result type.
+ std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
+ return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
+ }
+
+ private:
friend DenseElementsAttr;
/// Initializes the float element iterator to the specified iterator.
ComplexFloatElementIterator(const llvm::fltSemantics &smt,
- ComplexIntElementIterator it);
+ ComplexIntElementIterator it)
+ : BaseT(it), smt(&smt) {}
+
+ /// The float semantics to use when constructing the APFloat.
+ const llvm::fltSemantics *smt;
};
//===--------------------------------------------------------------------===//
@@ -478,24 +500,27 @@ class DenseElementsAttr : public Attribute {
typename std::enable_if<std::is_base_of<Attribute, T>::value &&
!std::is_same<Attribute, T>::value>::type;
template <typename T>
- using DerivedAttributeElementIterator =
- llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
+ struct DerivedAttributeElementIterator
+ : public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
+ AttributeElementIterator, T> {
+ using DerivedAttributeElementIterator::BaseT::BaseT;
+
+ /// Map the element to the iterator result type.
+ T mapElement(Attribute attr) const { return attr.cast<T>(); }
+ };
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
- auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return {Attribute::getType(),
- llvm::map_range(getValues<Attribute>(),
- static_cast<T (*)(Attribute)>(castFn))};
+ using DerivedIterT = DerivedAttributeElementIterator<T>;
+ return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
+ DerivedIterT(value_end<Attribute>())};
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_begin() const {
- auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+ return {value_begin<Attribute>()};
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_end() const {
- auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+ return {value_end<Attribute>()};
}
/// Return the held element values as a range of bool. The element type of
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index e8af2fe782f3d..ad1d22764165f 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -155,20 +155,6 @@ inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
class Diagnostic {
using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
- /// This class implements a wrapper iterator around NoteVector::iterator to
- /// implicitly dereference the unique_ptr.
- template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
- typename ResultTy = decltype(**IteratorTy())>
- class NoteIteratorImpl
- : public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
- static ResultTy &unwrap(NotePtrTy note) { return *note; }
-
- public:
- NoteIteratorImpl(IteratorTy it)
- : llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
- &unwrap) {}
- };
-
public:
Diagnostic(Location loc, DiagnosticSeverity severity)
: loc(loc), severity(severity) {}
@@ -262,15 +248,16 @@ class Diagnostic {
/// diagnostic. Notes may not be attached to other notes.
Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None);
- using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
- using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
+ using note_iterator = llvm::pointee_iterator<NoteVector::iterator>;
+ using const_note_iterator =
+ llvm::pointee_iterator<NoteVector::const_iterator>;
/// Returns the notes held by this diagnostic.
iterator_range<note_iterator> getNotes() {
- return {notes.begin(), notes.end()};
+ return llvm::make_pointee_range(notes);
}
iterator_range<const_note_iterator> getNotes() const {
- return {notes.begin(), notes.end()};
+ return llvm::make_pointee_range(notes);
}
/// Allow a diagnostic to be converted to 'failure'.
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index da777e4aaa613..e58be464e9214 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -111,20 +111,16 @@ class DialectInterfaceCollectionBase {
/// An iterator class that iterates the held interface objects of the given
/// derived interface type.
template <typename InterfaceT>
- class iterator : public llvm::mapped_iterator<
- InterfaceVectorT::const_iterator,
- const InterfaceT &(*)(const DialectInterface *)> {
- static const InterfaceT &remapIt(const DialectInterface *interface) {
+ struct iterator
+ : public llvm::mapped_iterator_base<iterator<InterfaceT>,
+ InterfaceVectorT::const_iterator,
+ const InterfaceT &> {
+ using iterator::BaseT::BaseT;
+
+ /// Map the element to the iterator result type.
+ const InterfaceT &mapElement(const DialectInterface *interface) const {
return *static_cast<const InterfaceT *>(interface);
}
-
- iterator(InterfaceVectorT::const_iterator it)
- : llvm::mapped_iterator<
- InterfaceVectorT::const_iterator,
- const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
-
- /// Allow access to the constructor.
- friend DialectInterfaceCollectionBase;
};
/// Iterator access to the held interfaces.
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index fa64f9a4abc7b..ad4ec67f19c24 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -124,16 +124,13 @@ class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
/// This class implements iteration on the types of a given range of values.
template <typename ValueIteratorT>
class ValueTypeIterator final
- : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
- static Type unwrap(Value value) { return value.getType(); }
-
+ : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
+ ValueIteratorT, Type> {
public:
- /// Provide a const dereference method.
- Type operator*() const { return unwrap(*this->I); }
+ using ValueTypeIterator::BaseT::BaseT;
- /// Initializes the type iterator to the specified value iterator.
- ValueTypeIterator(ValueIteratorT it)
- : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
+ /// Map the element to the iterator result type.
+ Type mapElement(Value value) const { return value.getType(); }
};
/// This class implements iteration on the types of a given range of values.
diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h
index a04ae8e1a4fe7..04bf36e95feb0 100644
--- a/mlir/include/mlir/IR/TypeUtilities.h
+++ b/mlir/include/mlir/IR/TypeUtilities.h
@@ -66,36 +66,33 @@ LogicalResult verifyCompatibleShapes(TypeRange types);
/// Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
+
//===----------------------------------------------------------------------===//
// Utility Iterators
//===----------------------------------------------------------------------===//
// An iterator for the element types of an op's operands of shaped types.
class OperandElementTypeIterator final
- : public llvm::mapped_iterator<Operation::operand_iterator,
- Type (*)(Value)> {
+ : public llvm::mapped_iterator_base<OperandElementTypeIterator,
+ Operation::operand_iterator, Type> {
public:
- /// Initializes the result element type iterator to the specified operand
- /// iterator.
- explicit OperandElementTypeIterator(Operation::operand_iterator it);
+ using BaseT::BaseT;
-private:
- static Type unwrap(Value value);
+ /// Map the element to the iterator result type.
+ Type mapElement(Value value) const;
};
using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
// An iterator for the tensor element types of an op's results of shaped types.
class ResultElementTypeIterator final
- : public llvm::mapped_iterator<Operation::result_iterator,
- Type (*)(Value)> {
+ : public llvm::mapped_iterator_base<ResultElementTypeIterator,
+ Operation::result_iterator, Type> {
public:
- /// Initializes the result element type iterator to the specified result
- /// iterator.
- explicit ResultElementTypeIterator(Operation::result_iterator it);
+ using BaseT::BaseT;
-private:
- static Type unwrap(Value value);
+ /// Map the element to the iterator result type.
+ Type mapElement(Value value) const;
};
using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 1971a67409bfb..d27ee58d295b6 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -281,15 +281,16 @@ class ValueUseIterator
/// a specific use iterator.
template <typename UseIteratorT, typename OperandType>
class ValueUserIterator final
- : public llvm::mapped_iterator<UseIteratorT,
- Operation *(*)(OperandType &)> {
- static Operation *unwrap(OperandType &value) { return value.getOwner(); }
-
+ : public llvm::mapped_iterator_base<
+ ValueUserIterator<UseIteratorT, OperandType>, UseIteratorT,
+ Operation *> {
public:
- /// Initializes the user iterator to the specified use iterator.
- ValueUserIterator(UseIteratorT it)
- : llvm::mapped_iterator<UseIteratorT, Operation *(*)(OperandType &)>(
- it, &unwrap) {}
+ using ValueUserIterator::BaseT::BaseT;
+
+ /// Map the element to the iterator result type.
+ Operation *mapElement(OperandType &value) const { return value.getOwner(); }
+
+ /// Provide access to the underlying operation.
Operation *operator->() { return **this; }
};
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index b99e988eb2276..6b48e31369040 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -667,27 +667,6 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
readBits(getData(), offset + storageWidth, bitWidth)};
}
-//===----------------------------------------------------------------------===//
-// FloatElementIterator
-
-DenseElementsAttr::FloatElementIterator::FloatElementIterator(
- const llvm::fltSemantics &smt, IntElementIterator it)
- : llvm::mapped_iterator<IntElementIterator,
- std::function<APFloat(const APInt &)>>(
- it, [&](const APInt &val) { return APFloat(smt, val); }) {}
-
-//===----------------------------------------------------------------------===//
-// ComplexFloatElementIterator
-
-DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
- const llvm::fltSemantics &smt, ComplexIntElementIterator it)
- : llvm::mapped_iterator<
- ComplexIntElementIterator,
- std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
- it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
- return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
- }) {}
-
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index bc6a0b9d9af3f..7533a60a444b1 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -151,20 +151,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
return success();
}
-OperandElementTypeIterator::OperandElementTypeIterator(
- Operation::operand_iterator it)
- : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
- it, &unwrap) {}
-
-Type OperandElementTypeIterator::unwrap(Value value) {
+Type OperandElementTypeIterator::mapElement(Value value) const {
return value.getType().cast<ShapedType>().getElementType();
}
-ResultElementTypeIterator::ResultElementTypeIterator(
- Operation::result_iterator it)
- : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
- it, &unwrap) {}
-
-Type ResultElementTypeIterator::unwrap(Value value) {
+Type ResultElementTypeIterator::mapElement(Value value) const {
return value.getType().cast<ShapedType>().getElementType();
}
More information about the Mlir-commits mailing list