[Mlir-commits] [mlir] 2125eb3 - [mlir][core] Slightly improved attribute lookup
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 4 14:49:02 PDT 2021
Author: Mogball
Date: 2021-11-04T21:48:58Z
New Revision: 2125eb3446d30c49f0997b5cd24535ded36528e0
URL: https://github.com/llvm/llvm-project/commit/2125eb3446d30c49f0997b5cd24535ded36528e0
DIFF: https://github.com/llvm/llvm-project/commit/2125eb3446d30c49f0997b5cd24535ded36528e0.diff
LOG: [mlir][core] Slightly improved attribute lookup
- String binary search does one fewer string comparison
- Identifier linear scans on large attribute lists are switched to string binary search
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D112970
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/OperationSupport.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 51ac32d9a564..01af84c421e9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -389,6 +389,10 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
Optional<NamedAttribute> getNamed(StringRef name) const;
Optional<NamedAttribute> getNamed(Identifier name) const;
+ /// Return whether the specified attribute is present.
+ bool contains(StringRef name) const;
+ bool contains(Identifier name) const;
+
/// Support range iteration.
using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
iterator begin() const;
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 9862112f72d9..15f8e6b06146 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -332,8 +332,8 @@ class alignas(8) Operation final
/// Return true if the operation has an attribute with the provided name,
/// false otherwise.
- bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
- bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+ bool hasAttr(Identifier name) { return attrs.contains(name); }
+ bool hasAttr(StringRef name) { return attrs.contains(name); }
template <typename AttrClass, typename NameT>
bool hasAttrOfType(NameT &&name) {
return static_cast<bool>(
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f56727e51cb0..ca3a1a61f85f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -245,6 +245,63 @@ class AbstractOperation {
ArrayRef<Identifier> attributeNames;
};
+//===----------------------------------------------------------------------===//
+// Attribute Dictionary-Like Interface
+//===----------------------------------------------------------------------===//
+
+/// Attribute collections provide a dictionary-like interface. Define common
+/// lookup functions.
+namespace impl {
+
+/// Unsorted string search or identifier lookups are linear scans.
+template <typename IteratorT, typename NameT>
+std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last,
+ NameT name) {
+ for (auto it = first; it != last; ++it)
+ if (it->first == name)
+ return {it, true};
+ return {last, false};
+}
+
+/// Using llvm::lower_bound requires an extra string comparison to check whether
+/// the returned iterator points to the found element or whether it indicates
+/// the lower bound. Skip this redundant comparison by checking if `compare ==
+/// 0` during the binary search.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+ StringRef name) {
+ ptrdiff_t length = std::distance(first, last);
+
+ while (length > 0) {
+ ptrdiff_t half = length / 2;
+ IteratorT mid = first + half;
+ int compare = mid->first.strref().compare(name);
+ if (compare < 0) {
+ first = mid + 1;
+ length = length - half - 1;
+ } else if (compare > 0) {
+ length = half;
+ } else {
+ return {mid, true};
+ }
+ }
+ return {first, false};
+}
+
+/// Identifier lookups on large attribute lists will switch to string binary
+/// search. String binary searches become significantly faster than linear scans
+/// with the identifier when the attribute list becomes very large.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+ Identifier name) {
+ constexpr unsigned kSmallAttributeList = 16;
+ if (std::distance(first, last) > kSmallAttributeList)
+ return findAttrSorted(first, last, name.strref());
+ return findAttrUnsorted(first, last, name);
+}
+
+} // end namespace impl
+
//===----------------------------------------------------------------------===//
// NamedAttrList
//===----------------------------------------------------------------------===//
@@ -253,9 +310,10 @@ class AbstractOperation {
/// and does some basic work to remain sorted.
class NamedAttrList {
public:
+ using iterator = SmallVectorImpl<NamedAttribute>::iterator;
using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator;
- using const_reference = const NamedAttribute &;
using reference = NamedAttribute &;
+ using const_reference = const NamedAttribute &;
using size_type = size_t;
NamedAttrList() : dictionarySorted({}, true) {}
@@ -346,6 +404,8 @@ class NamedAttrList {
Attribute erase(Identifier name);
Attribute erase(StringRef name);
+ iterator begin() { return attrs.begin(); }
+ iterator end() { return attrs.end(); }
const_iterator begin() const { return attrs.begin(); }
const_iterator end() const { return attrs.end(); }
@@ -359,6 +419,14 @@ class NamedAttrList {
/// Erase the attribute at the given iterator position.
Attribute eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it);
+ /// Lookup an attribute in the list.
+ template <typename AttrListT, typename NameT>
+ static auto findAttr(AttrListT &attrs, NameT name) {
+ return attrs.isSorted()
+ ? impl::findAttrSorted(attrs.begin(), attrs.end(), name)
+ : impl::findAttrUnsorted(attrs.begin(), attrs.end(), name);
+ }
+
// These are marked mutable as they may be modified (e.g., sorted)
mutable SmallVector<NamedAttribute, 4> attrs;
// Pair with cached DictionaryAttr and status of whether attrs is sorted.
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 2acc386d259e..41a8c46c0c6d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -185,26 +185,30 @@ DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
/// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const {
- Optional<NamedAttribute> attr = getNamed(name);
- return attr ? attr->second : nullptr;
+ auto it = impl::findAttrSorted(begin(), end(), name);
+ return it.second ? it.first->second : Attribute();
}
Attribute DictionaryAttr::get(Identifier name) const {
- Optional<NamedAttribute> attr = getNamed(name);
- return attr ? attr->second : nullptr;
+ auto it = impl::findAttrSorted(begin(), end(), name);
+ return it.second ? it.first->second : Attribute();
}
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
- ArrayRef<NamedAttribute> values = getValue();
- const auto *it = llvm::lower_bound(values, name);
- return it != values.end() && it->first == name ? *it
- : Optional<NamedAttribute>();
+ auto it = impl::findAttrSorted(begin(), end(), name);
+ return it.second ? *it.first : Optional<NamedAttribute>();
}
Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
- for (auto elt : getValue())
- if (elt.first == name)
- return elt;
- return llvm::None;
+ auto it = impl::findAttrSorted(begin(), end(), name);
+ return it.second ? *it.first : Optional<NamedAttribute>();
+}
+
+/// Return whether the specified attribute is present.
+bool DictionaryAttr::contains(StringRef name) const {
+ return impl::findAttrSorted(begin(), end(), name).second;
+}
+bool DictionaryAttr::contains(Identifier name) const {
+ return impl::findAttrSorted(begin(), end(), name).second;
}
DictionaryAttr::iterator DictionaryAttr::begin() const {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 1202e0f6275d..4c9bc848ce47 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -81,42 +81,24 @@ void NamedAttrList::push_back(NamedAttribute newAttribute) {
attrs.push_back(newAttribute);
}
-/// Helper function to find attribute in possible sorted vector of
-/// NamedAttributes.
-template <typename T>
-static auto *findAttr(SmallVectorImpl<NamedAttribute> &attrs, T name,
- bool sorted) {
- if (!sorted) {
- return llvm::find_if(
- attrs, [name](NamedAttribute attr) { return attr.first == name; });
- }
-
- auto *it = llvm::lower_bound(attrs, name);
- if (it == attrs.end() || it->first != name)
- return attrs.end();
- return it;
-}
-
/// Return the specified attribute if present, null otherwise.
Attribute NamedAttrList::get(StringRef name) const {
- auto *it = findAttr(attrs, name, isSorted());
- return it != attrs.end() ? it->second : nullptr;
+ auto it = findAttr(*this, name);
+ return it.second ? it.first->second : Attribute();
}
-
-/// Return the specified attribute if present, null otherwise.
Attribute NamedAttrList::get(Identifier name) const {
- auto *it = findAttr(attrs, name, isSorted());
- return it != attrs.end() ? it->second : nullptr;
+ auto it = findAttr(*this, name);
+ return it.second ? it.first->second : Attribute();
}
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
- auto *it = findAttr(attrs, name, isSorted());
- return it != attrs.end() ? *it : Optional<NamedAttribute>();
+ auto it = findAttr(*this, name);
+ return it.second ? *it.first : Optional<NamedAttribute>();
}
Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
- auto *it = findAttr(attrs, name, isSorted());
- return it != attrs.end() ? *it : Optional<NamedAttribute>();
+ auto it = findAttr(*this, name);
+ return it.second ? *it.first : Optional<NamedAttribute>();
}
/// If the an attribute exists with the specified name, change it to the new
@@ -124,34 +106,36 @@ Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
Attribute NamedAttrList::set(Identifier name, Attribute value) {
assert(value && "attributes may never be null");
- // Look for an existing value for the given name, and set it in-place.
- auto *it = findAttr(attrs, name, isSorted());
- if (it != attrs.end()) {
- // Only update if the value is different from the existing.
- Attribute oldValue = it->second;
- if (oldValue != value) {
+ // Look for an existing attribute with the given name, and set its value
+ // in-place. Return the previous value of the attribute, if there was one.
+ auto it = findAttr(*this, name);
+ if (it.second) {
+ // Update the existing attribute by swapping out the old value for the new
+ // value. Return the old value.
+ if (it.first->second != value) {
+ std::swap(it.first->second, value);
+ // If the attributes have changed, the dictionary is invalidated.
dictionarySorted.setPointer(nullptr);
- it->second = value;
}
- return oldValue;
+ return value;
}
-
- // Otherwise, insert the new attribute into its sorted position.
- it = llvm::lower_bound(attrs, name);
+ // Perform a string lookup to insert the new attribute into its sorted
+ // position.
+ if (isSorted())
+ it = findAttr(*this, name.strref());
+ attrs.insert(it.first, {name, value});
+ // Invalidate the dictionary. Return null as there was no previous value.
dictionarySorted.setPointer(nullptr);
- attrs.insert(it, {name, value});
return Attribute();
}
+
Attribute NamedAttrList::set(StringRef name, Attribute value) {
- assert(value && "setting null attribute not supported");
+ assert(value && "attributes may never be null");
return set(mlir::Identifier::get(name, value.getContext()), value);
}
Attribute
NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
- if (it == attrs.end())
- return nullptr;
-
// Erasing does not affect the sorted property.
Attribute attr = it->second;
attrs.erase(it);
@@ -160,11 +144,13 @@ NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
}
Attribute NamedAttrList::erase(Identifier name) {
- return eraseImpl(findAttr(attrs, name, isSorted()));
+ auto it = findAttr(*this, name);
+ return it.second ? eraseImpl(it.first) : Attribute();
}
Attribute NamedAttrList::erase(StringRef name) {
- return eraseImpl(findAttr(attrs, name, isSorted()));
+ auto it = findAttr(*this, name);
+ return it.second ? eraseImpl(it.first) : Attribute();
}
NamedAttrList &
More information about the Mlir-commits
mailing list