[Mlir-commits] [mlir] 6de6131 - [mlir] Optimize usage of llvm::mapped_iterator

River Riddle llvmlistbot at llvm.org
Wed Nov 10 19:44:54 PST 2021


Author: River Riddle
Date: 2021-11-11T03:26:29Z
New Revision: 6de6131f029d597188b05e666fb77fb69e5b936c

URL: https://github.com/llvm/llvm-project/commit/6de6131f029d597188b05e666fb77fb69e5b936c
DIFF: https://github.com/llvm/llvm-project/commit/6de6131f029d597188b05e666fb77fb69e5b936c.diff

LOG: [mlir] Optimize usage of llvm::mapped_iterator

mapped_iterator is a useful abstraction for applying a
map function over an existing iterator, but our current
usage ends up allocating storage/making indirect calls
even when the map function is a known function, which
is horribly inefficient. This commit refactors the usage
of mapped_iterator to avoid this, and allows for directly
referencing the map function when dereferencing.

Fixes PR52319

Differential Revision: https://reviews.llvm.org/D113511

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/unittests/ADT/MappedIteratorTest.cpp
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/Diagnostics.h
    mlir/include/mlir/IR/DialectInterface.h
    mlir/include/mlir/IR/TypeRange.h
    mlir/include/mlir/IR/TypeUtilities.h
    mlir/include/mlir/IR/UseDefLists.h
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/TypeUtilities.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index daa6d257dd000..6c2fb1029a918 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -307,6 +307,32 @@ auto map_range(ContainerTy &&C, FuncTy F) {
   return make_range(map_iterator(C.begin(), F), map_iterator(C.end(), F));
 }
 
+/// A base type of mapped iterator, that is useful for building derived
+/// iterators that do not need/want to store the map function (as in
+/// mapped_iterator). These iterators must simply provide a `mapElement` method
+/// that defines how to map a value of the iterator to the provided reference
+/// type.
+template <typename DerivedT, typename ItTy, typename ReferenceTy>
+class mapped_iterator_base
+    : public iterator_adaptor_base<
+          DerivedT, ItTy,
+          typename std::iterator_traits<ItTy>::iterator_category,
+          std::remove_reference_t<ReferenceTy>,
+          typename std::iterator_traits<ItTy>::difference_type,
+          std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
+public:
+  using BaseT = mapped_iterator_base<DerivedT, ItTy, ReferenceTy>;
+
+  mapped_iterator_base(ItTy U)
+      : mapped_iterator_base::iterator_adaptor_base(std::move(U)) {}
+
+  ItTy getCurrent() { return this->I; }
+
+  ReferenceTy operator*() const {
+    return static_cast<const DerivedT &>(*this).mapElement(*this->I);
+  }
+};
+
 /// Helper to determine if type T has a member called rbegin().
 template <typename Ty> class has_rbegin_impl {
   using yes = char[1];

diff  --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp
index 61e45e3a66eb1..f94709805c2cd 100644
--- a/llvm/unittests/ADT/MappedIteratorTest.cpp
+++ b/llvm/unittests/ADT/MappedIteratorTest.cpp
@@ -47,4 +47,67 @@ TEST(MappedIteratorTest, FunctionPreservesReferences) {
   EXPECT_EQ(M[1], 42) << "assignment should have modified M";
 }
 
+TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnDereference) {
+  struct CustomMapIterator
+      : public llvm::mapped_iterator_base<CustomMapIterator,
+                                          std::vector<int>::iterator, int> {
+    using BaseT::BaseT;
+
+    /// Map the element to the iterator result type.
+    int mapElement(int X) const { return X + 1; }
+  };
+
+  std::vector<int> V({0});
+
+  CustomMapIterator I(V.begin());
+
+  EXPECT_EQ(*I, 1) << "should have applied function in dereference";
+}
+
+TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnArrow) {
+  struct S {
+    int Z = 0;
+  };
+  struct CustomMapIterator
+      : public llvm::mapped_iterator_base<CustomMapIterator,
+                                          std::vector<int>::iterator, S &> {
+    CustomMapIterator(std::vector<int>::iterator it, S *P) : BaseT(it), P(P) {}
+
+    /// Map the element to the iterator result type.
+    S &mapElement(int X) const { return *(P + X); }
+
+    S *P;
+  };
+
+  std::vector<int> V({0});
+  S Y;
+
+  CustomMapIterator I(V.begin(), &Y);
+
+  I->Z = 42;
+
+  EXPECT_EQ(Y.Z, 42) << "should have applied function during arrow";
+}
+
+TEST(MappedIteratorTest, CustomIteratorFunctionPreservesReferences) {
+  struct CustomMapIterator
+      : public llvm::mapped_iterator_base<CustomMapIterator,
+                                          std::vector<int>::iterator, int &> {
+    CustomMapIterator(std::vector<int>::iterator it, std::map<int, int> &M)
+        : BaseT(it), M(M) {}
+
+    /// Map the element to the iterator result type.
+    int &mapElement(int X) const { return M[X]; }
+
+    std::map<int, int> &M;
+  };
+  std::vector<int> V({1});
+  std::map<int, int> M({{1, 1}});
+
+  auto I = CustomMapIterator(V.begin(), M);
+  *I = 42;
+
+  EXPECT_EQ(M[1], 42) << "assignment should have modified M";
+}
+
 } // anonymous namespace

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 6dabb6f3bf06c..db58d58d5453b 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -323,24 +323,46 @@ class DenseElementsAttr : public Attribute {
 
   /// Iterator for walking over APFloat values.
   class FloatElementIterator final
-      : public llvm::mapped_iterator<IntElementIterator,
-                                     std::function<APFloat(const APInt &)>> {
+      : public llvm::mapped_iterator_base<FloatElementIterator,
+                                          IntElementIterator, APFloat> {
+  public:
+    /// Map the element to the iterator result type.
+    APFloat mapElement(const APInt &value) const {
+      return APFloat(*smt, value);
+    }
+
+  private:
     friend DenseElementsAttr;
 
     /// Initializes the float element iterator to the specified iterator.
-    FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
+    FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
+        : BaseT(it), smt(&smt) {}
+
+    /// The float semantics to use when constructing the APFloat.
+    const llvm::fltSemantics *smt;
   };
 
   /// Iterator for walking over complex APFloat values.
   class ComplexFloatElementIterator final
-      : public llvm::mapped_iterator<
-            ComplexIntElementIterator,
-            std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
+      : public llvm::mapped_iterator_base<ComplexFloatElementIterator,
+                                          ComplexIntElementIterator,
+                                          std::complex<APFloat>> {
+  public:
+    /// Map the element to the iterator result type.
+    std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
+      return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
+    }
+
+  private:
     friend DenseElementsAttr;
 
     /// Initializes the float element iterator to the specified iterator.
     ComplexFloatElementIterator(const llvm::fltSemantics &smt,
-                                ComplexIntElementIterator it);
+                                ComplexIntElementIterator it)
+        : BaseT(it), smt(&smt) {}
+
+    /// The float semantics to use when constructing the APFloat.
+    const llvm::fltSemantics *smt;
   };
 
   //===--------------------------------------------------------------------===//
@@ -478,24 +500,27 @@ class DenseElementsAttr : public Attribute {
       typename std::enable_if<std::is_base_of<Attribute, T>::value &&
                               !std::is_same<Attribute, T>::value>::type;
   template <typename T>
-  using DerivedAttributeElementIterator =
-      llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
+  struct DerivedAttributeElementIterator
+      : public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
+                                          AttributeElementIterator, T> {
+    using DerivedAttributeElementIterator::BaseT::BaseT;
+
+    /// Map the element to the iterator result type.
+    T mapElement(Attribute attr) const { return attr.cast<T>(); }
+  };
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
-    auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-    return {Attribute::getType(),
-            llvm::map_range(getValues<Attribute>(),
-                            static_cast<T (*)(Attribute)>(castFn))};
+    using DerivedIterT = DerivedAttributeElementIterator<T>;
+    return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
+            DerivedIterT(value_end<Attribute>())};
   }
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   DerivedAttributeElementIterator<T> value_begin() const {
-    auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-    return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+    return {value_begin<Attribute>()};
   }
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   DerivedAttributeElementIterator<T> value_end() const {
-    auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-    return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+    return {value_end<Attribute>()};
   }
 
   /// Return the held element values as a range of bool. The element type of

diff  --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index e8af2fe782f3d..ad1d22764165f 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -155,20 +155,6 @@ inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
 class Diagnostic {
   using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
 
-  /// This class implements a wrapper iterator around NoteVector::iterator to
-  /// implicitly dereference the unique_ptr.
-  template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
-            typename ResultTy = decltype(**IteratorTy())>
-  class NoteIteratorImpl
-      : public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
-    static ResultTy &unwrap(NotePtrTy note) { return *note; }
-
-  public:
-    NoteIteratorImpl(IteratorTy it)
-        : llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
-                                                                     &unwrap) {}
-  };
-
 public:
   Diagnostic(Location loc, DiagnosticSeverity severity)
       : loc(loc), severity(severity) {}
@@ -262,15 +248,16 @@ class Diagnostic {
   /// diagnostic. Notes may not be attached to other notes.
   Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None);
 
-  using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
-  using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
+  using note_iterator = llvm::pointee_iterator<NoteVector::iterator>;
+  using const_note_iterator =
+      llvm::pointee_iterator<NoteVector::const_iterator>;
 
   /// Returns the notes held by this diagnostic.
   iterator_range<note_iterator> getNotes() {
-    return {notes.begin(), notes.end()};
+    return llvm::make_pointee_range(notes);
   }
   iterator_range<const_note_iterator> getNotes() const {
-    return {notes.begin(), notes.end()};
+    return llvm::make_pointee_range(notes);
   }
 
   /// Allow a diagnostic to be converted to 'failure'.

diff  --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index da777e4aaa613..e58be464e9214 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -111,20 +111,16 @@ class DialectInterfaceCollectionBase {
   /// An iterator class that iterates the held interface objects of the given
   /// derived interface type.
   template <typename InterfaceT>
-  class iterator : public llvm::mapped_iterator<
-                       InterfaceVectorT::const_iterator,
-                       const InterfaceT &(*)(const DialectInterface *)> {
-    static const InterfaceT &remapIt(const DialectInterface *interface) {
+  struct iterator
+      : public llvm::mapped_iterator_base<iterator<InterfaceT>,
+                                          InterfaceVectorT::const_iterator,
+                                          const InterfaceT &> {
+    using iterator::BaseT::BaseT;
+
+    /// Map the element to the iterator result type.
+    const InterfaceT &mapElement(const DialectInterface *interface) const {
       return *static_cast<const InterfaceT *>(interface);
     }
-
-    iterator(InterfaceVectorT::const_iterator it)
-        : llvm::mapped_iterator<
-              InterfaceVectorT::const_iterator,
-              const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
-
-    /// Allow access to the constructor.
-    friend DialectInterfaceCollectionBase;
   };
 
   /// Iterator access to the held interfaces.

diff  --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index fa64f9a4abc7b..ad4ec67f19c24 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -124,16 +124,13 @@ class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
 /// This class implements iteration on the types of a given range of values.
 template <typename ValueIteratorT>
 class ValueTypeIterator final
-    : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
-  static Type unwrap(Value value) { return value.getType(); }
-
+    : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
+                                        ValueIteratorT, Type> {
 public:
-  /// Provide a const dereference method.
-  Type operator*() const { return unwrap(*this->I); }
+  using ValueTypeIterator::BaseT::BaseT;
 
-  /// Initializes the type iterator to the specified value iterator.
-  ValueTypeIterator(ValueIteratorT it)
-      : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
+  /// Map the element to the iterator result type.
+  Type mapElement(Value value) const { return value.getType(); }
 };
 
 /// This class implements iteration on the types of a given range of values.

diff  --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h
index a04ae8e1a4fe7..04bf36e95feb0 100644
--- a/mlir/include/mlir/IR/TypeUtilities.h
+++ b/mlir/include/mlir/IR/TypeUtilities.h
@@ -66,36 +66,33 @@ LogicalResult verifyCompatibleShapes(TypeRange types);
 
 /// Dimensions are compatible if all non-dynamic dims are equal.
 LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
+
 //===----------------------------------------------------------------------===//
 // Utility Iterators
 //===----------------------------------------------------------------------===//
 
 // An iterator for the element types of an op's operands of shaped types.
 class OperandElementTypeIterator final
-    : public llvm::mapped_iterator<Operation::operand_iterator,
-                                   Type (*)(Value)> {
+    : public llvm::mapped_iterator_base<OperandElementTypeIterator,
+                                        Operation::operand_iterator, Type> {
 public:
-  /// Initializes the result element type iterator to the specified operand
-  /// iterator.
-  explicit OperandElementTypeIterator(Operation::operand_iterator it);
+  using BaseT::BaseT;
 
-private:
-  static Type unwrap(Value value);
+  /// Map the element to the iterator result type.
+  Type mapElement(Value value) const;
 };
 
 using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
 
 // An iterator for the tensor element types of an op's results of shaped types.
 class ResultElementTypeIterator final
-    : public llvm::mapped_iterator<Operation::result_iterator,
-                                   Type (*)(Value)> {
+    : public llvm::mapped_iterator_base<ResultElementTypeIterator,
+                                        Operation::result_iterator, Type> {
 public:
-  /// Initializes the result element type iterator to the specified result
-  /// iterator.
-  explicit ResultElementTypeIterator(Operation::result_iterator it);
+  using BaseT::BaseT;
 
-private:
-  static Type unwrap(Value value);
+  /// Map the element to the iterator result type.
+  Type mapElement(Value value) const;
 };
 
 using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;

diff  --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 1971a67409bfb..d27ee58d295b6 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -281,15 +281,16 @@ class ValueUseIterator
 /// a specific use iterator.
 template <typename UseIteratorT, typename OperandType>
 class ValueUserIterator final
-    : public llvm::mapped_iterator<UseIteratorT,
-                                   Operation *(*)(OperandType &)> {
-  static Operation *unwrap(OperandType &value) { return value.getOwner(); }
-
+    : public llvm::mapped_iterator_base<
+          ValueUserIterator<UseIteratorT, OperandType>, UseIteratorT,
+          Operation *> {
 public:
-  /// Initializes the user iterator to the specified use iterator.
-  ValueUserIterator(UseIteratorT it)
-      : llvm::mapped_iterator<UseIteratorT, Operation *(*)(OperandType &)>(
-            it, &unwrap) {}
+  using ValueUserIterator::BaseT::BaseT;
+
+  /// Map the element to the iterator result type.
+  Operation *mapElement(OperandType &value) const { return value.getOwner(); }
+
+  /// Provide access to the underlying operation.
   Operation *operator->() { return **this; }
 };
 

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index b99e988eb2276..6b48e31369040 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -667,27 +667,6 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
           readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
-//===----------------------------------------------------------------------===//
-// FloatElementIterator
-
-DenseElementsAttr::FloatElementIterator::FloatElementIterator(
-    const llvm::fltSemantics &smt, IntElementIterator it)
-    : llvm::mapped_iterator<IntElementIterator,
-                            std::function<APFloat(const APInt &)>>(
-          it, [&](const APInt &val) { return APFloat(smt, val); }) {}
-
-//===----------------------------------------------------------------------===//
-// ComplexFloatElementIterator
-
-DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
-    const llvm::fltSemantics &smt, ComplexIntElementIterator it)
-    : llvm::mapped_iterator<
-          ComplexIntElementIterator,
-          std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
-          it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
-            return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
-          }) {}
-
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index bc6a0b9d9af3f..7533a60a444b1 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -151,20 +151,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
   return success();
 }
 
-OperandElementTypeIterator::OperandElementTypeIterator(
-    Operation::operand_iterator it)
-    : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
-          it, &unwrap) {}
-
-Type OperandElementTypeIterator::unwrap(Value value) {
+Type OperandElementTypeIterator::mapElement(Value value) const {
   return value.getType().cast<ShapedType>().getElementType();
 }
 
-ResultElementTypeIterator::ResultElementTypeIterator(
-    Operation::result_iterator it)
-    : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
-          it, &unwrap) {}
-
-Type ResultElementTypeIterator::unwrap(Value value) {
+Type ResultElementTypeIterator::mapElement(Value value) const {
   return value.getType().cast<ShapedType>().getElementType();
 }


        


More information about the Mlir-commits mailing list