[Mlir-commits] [mlir] 2125eb3 - [mlir][core] Slightly improved attribute lookup

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 4 14:49:02 PDT 2021


Author: Mogball
Date: 2021-11-04T21:48:58Z
New Revision: 2125eb3446d30c49f0997b5cd24535ded36528e0

URL: https://github.com/llvm/llvm-project/commit/2125eb3446d30c49f0997b5cd24535ded36528e0
DIFF: https://github.com/llvm/llvm-project/commit/2125eb3446d30c49f0997b5cd24535ded36528e0.diff

LOG: [mlir][core] Slightly improved attribute lookup

- String binary search performs one fewer string comparison
- Identifier lookups on large attribute lists now use a string binary search instead of a linear scan

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D112970

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/OperationSupport.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 51ac32d9a564..01af84c421e9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -389,6 +389,10 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
     Optional<NamedAttribute> getNamed(StringRef name) const;
     Optional<NamedAttribute> getNamed(Identifier name) const;
 
+    /// Return whether the specified attribute is present.
+    bool contains(StringRef name) const;
+    bool contains(Identifier name) const;
+
     /// Support range iteration.
     using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
     iterator begin() const;

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 9862112f72d9..15f8e6b06146 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -332,8 +332,8 @@ class alignas(8) Operation final
 
   /// Return true if the operation has an attribute with the provided name,
   /// false otherwise.
-  bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
-  bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+  bool hasAttr(Identifier name) { return attrs.contains(name); }
+  bool hasAttr(StringRef name) { return attrs.contains(name); }
   template <typename AttrClass, typename NameT>
   bool hasAttrOfType(NameT &&name) {
     return static_cast<bool>(

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f56727e51cb0..ca3a1a61f85f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -245,6 +245,63 @@ class AbstractOperation {
   ArrayRef<Identifier> attributeNames;
 };
 
+//===----------------------------------------------------------------------===//
+// Attribute Dictionary-Like Interface
+//===----------------------------------------------------------------------===//
+
+/// Attribute collections provide a dictionary-like interface. Define common
+/// lookup functions.
+namespace impl {
+
+/// Unsorted string search or identifier lookups are linear scans.
+template <typename IteratorT, typename NameT>
+std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last,
+                                            NameT name) {
+  for (auto it = first; it != last; ++it)
+    if (it->first == name)
+      return {it, true};
+  return {last, false};
+}
+
+/// Using llvm::lower_bound requires an extra string comparison to check whether
+/// the returned iterator points to the found element or whether it indicates
+/// the lower bound. Skip this redundant comparison by checking if `compare ==
+/// 0` during the binary search.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+                                          StringRef name) {
+  ptrdiff_t length = std::distance(first, last);
+
+  while (length > 0) {
+    ptrdiff_t half = length / 2;
+    IteratorT mid = first + half;
+    int compare = mid->first.strref().compare(name);
+    if (compare < 0) {
+      first = mid + 1;
+      length = length - half - 1;
+    } else if (compare > 0) {
+      length = half;
+    } else {
+      return {mid, true};
+    }
+  }
+  return {first, false};
+}
+
+/// Identifier lookups on large attribute lists will switch to string binary
+/// search. String binary searches become significantly faster than linear scans
+/// with the identifier when the attribute list becomes very large.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+                                          Identifier name) {
+  constexpr unsigned kSmallAttributeList = 16;
+  if (std::distance(first, last) > kSmallAttributeList)
+    return findAttrSorted(first, last, name.strref());
+  return findAttrUnsorted(first, last, name);
+}
+
+} // end namespace impl
+
 //===----------------------------------------------------------------------===//
 // NamedAttrList
 //===----------------------------------------------------------------------===//
@@ -253,9 +310,10 @@ class AbstractOperation {
 /// and does some basic work to remain sorted.
 class NamedAttrList {
 public:
+  using iterator = SmallVectorImpl<NamedAttribute>::iterator;
   using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator;
-  using const_reference = const NamedAttribute &;
   using reference = NamedAttribute &;
+  using const_reference = const NamedAttribute &;
   using size_type = size_t;
 
   NamedAttrList() : dictionarySorted({}, true) {}
@@ -346,6 +404,8 @@ class NamedAttrList {
   Attribute erase(Identifier name);
   Attribute erase(StringRef name);
 
+  iterator begin() { return attrs.begin(); }
+  iterator end() { return attrs.end(); }
   const_iterator begin() const { return attrs.begin(); }
   const_iterator end() const { return attrs.end(); }
 
@@ -359,6 +419,14 @@ class NamedAttrList {
   /// Erase the attribute at the given iterator position.
   Attribute eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it);
 
+  /// Lookup an attribute in the list.
+  template <typename AttrListT, typename NameT>
+  static auto findAttr(AttrListT &attrs, NameT name) {
+    return attrs.isSorted()
+               ? impl::findAttrSorted(attrs.begin(), attrs.end(), name)
+               : impl::findAttrUnsorted(attrs.begin(), attrs.end(), name);
+  }
+
   // These are marked mutable as they may be modified (e.g., sorted)
   mutable SmallVector<NamedAttribute, 4> attrs;
   // Pair with cached DictionaryAttr and status of whether attrs is sorted.

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 2acc386d259e..41a8c46c0c6d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -185,26 +185,30 @@ DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
 
 /// Return the specified attribute if present, null otherwise.
 Attribute DictionaryAttr::get(StringRef name) const {
-  Optional<NamedAttribute> attr = getNamed(name);
-  return attr ? attr->second : nullptr;
+  auto it = impl::findAttrSorted(begin(), end(), name);
+  return it.second ? it.first->second : Attribute();
 }
 Attribute DictionaryAttr::get(Identifier name) const {
-  Optional<NamedAttribute> attr = getNamed(name);
-  return attr ? attr->second : nullptr;
+  auto it = impl::findAttrSorted(begin(), end(), name);
+  return it.second ? it.first->second : Attribute();
 }
 
 /// Return the specified named attribute if present, None otherwise.
 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
-  ArrayRef<NamedAttribute> values = getValue();
-  const auto *it = llvm::lower_bound(values, name);
-  return it != values.end() && it->first == name ? *it
-                                                 : Optional<NamedAttribute>();
+  auto it = impl::findAttrSorted(begin(), end(), name);
+  return it.second ? *it.first : Optional<NamedAttribute>();
 }
 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
-  for (auto elt : getValue())
-    if (elt.first == name)
-      return elt;
-  return llvm::None;
+  auto it = impl::findAttrSorted(begin(), end(), name);
+  return it.second ? *it.first : Optional<NamedAttribute>();
+}
+
+/// Return whether the specified attribute is present.
+bool DictionaryAttr::contains(StringRef name) const {
+  return impl::findAttrSorted(begin(), end(), name).second;
+}
+bool DictionaryAttr::contains(Identifier name) const {
+  return impl::findAttrSorted(begin(), end(), name).second;
 }
 
 DictionaryAttr::iterator DictionaryAttr::begin() const {

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 1202e0f6275d..4c9bc848ce47 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -81,42 +81,24 @@ void NamedAttrList::push_back(NamedAttribute newAttribute) {
   attrs.push_back(newAttribute);
 }
 
-/// Helper function to find attribute in possible sorted vector of
-/// NamedAttributes.
-template <typename T>
-static auto *findAttr(SmallVectorImpl<NamedAttribute> &attrs, T name,
-                      bool sorted) {
-  if (!sorted) {
-    return llvm::find_if(
-        attrs, [name](NamedAttribute attr) { return attr.first == name; });
-  }
-
-  auto *it = llvm::lower_bound(attrs, name);
-  if (it == attrs.end() || it->first != name)
-    return attrs.end();
-  return it;
-}
-
 /// Return the specified attribute if present, null otherwise.
 Attribute NamedAttrList::get(StringRef name) const {
-  auto *it = findAttr(attrs, name, isSorted());
-  return it != attrs.end() ? it->second : nullptr;
+  auto it = findAttr(*this, name);
+  return it.second ? it.first->second : Attribute();
 }
-
-/// Return the specified attribute if present, null otherwise.
 Attribute NamedAttrList::get(Identifier name) const {
-  auto *it = findAttr(attrs, name, isSorted());
-  return it != attrs.end() ? it->second : nullptr;
+  auto it = findAttr(*this, name);
+  return it.second ? it.first->second : Attribute();
 }
 
 /// Return the specified named attribute if present, None otherwise.
 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
-  auto *it = findAttr(attrs, name, isSorted());
-  return it != attrs.end() ? *it : Optional<NamedAttribute>();
+  auto it = findAttr(*this, name);
+  return it.second ? *it.first : Optional<NamedAttribute>();
 }
 Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
-  auto *it = findAttr(attrs, name, isSorted());
-  return it != attrs.end() ? *it : Optional<NamedAttribute>();
+  auto it = findAttr(*this, name);
+  return it.second ? *it.first : Optional<NamedAttribute>();
 }
 
 /// If the an attribute exists with the specified name, change it to the new
@@ -124,34 +106,36 @@ Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
 Attribute NamedAttrList::set(Identifier name, Attribute value) {
   assert(value && "attributes may never be null");
 
-  // Look for an existing value for the given name, and set it in-place.
-  auto *it = findAttr(attrs, name, isSorted());
-  if (it != attrs.end()) {
-    // Only update if the value is different from the existing.
-    Attribute oldValue = it->second;
-    if (oldValue != value) {
+  // Look for an existing attribute with the given name, and set its value
+  // in-place. Return the previous value of the attribute, if there was one.
+  auto it = findAttr(*this, name);
+  if (it.second) {
+    // Update the existing attribute by swapping out the old value for the new
+    // value. Return the old value.
+    if (it.first->second != value) {
+      std::swap(it.first->second, value);
+      // If the attributes have changed, the dictionary is invalidated.
       dictionarySorted.setPointer(nullptr);
-      it->second = value;
     }
-    return oldValue;
+    return value;
   }
-
-  // Otherwise, insert the new attribute into its sorted position.
-  it = llvm::lower_bound(attrs, name);
+  // Perform a string lookup to insert the new attribute into its sorted
+  // position.
+  if (isSorted())
+    it = findAttr(*this, name.strref());
+  attrs.insert(it.first, {name, value});
+  // Invalidate the dictionary. Return null as there was no previous value.
   dictionarySorted.setPointer(nullptr);
-  attrs.insert(it, {name, value});
   return Attribute();
 }
+
 Attribute NamedAttrList::set(StringRef name, Attribute value) {
-  assert(value && "setting null attribute not supported");
+  assert(value && "attributes may never be null");
   return set(mlir::Identifier::get(name, value.getContext()), value);
 }
 
 Attribute
 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
-  if (it == attrs.end())
-    return nullptr;
-
   // Erasing does not affect the sorted property.
   Attribute attr = it->second;
   attrs.erase(it);
@@ -160,11 +144,13 @@ NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
 }
 
 Attribute NamedAttrList::erase(Identifier name) {
-  return eraseImpl(findAttr(attrs, name, isSorted()));
+  auto it = findAttr(*this, name);
+  return it.second ? eraseImpl(it.first) : Attribute();
 }
 
 Attribute NamedAttrList::erase(StringRef name) {
-  return eraseImpl(findAttr(attrs, name, isSorted()));
+  auto it = findAttr(*this, name);
+  return it.second ? eraseImpl(it.first) : Attribute();
 }
 
 NamedAttrList &


        


More information about the Mlir-commits mailing list