[Mlir-commits] [mlir] 0e2bd49 - [mlir][DictionaryAttr] Add a new getWithSorted and use it when possible

River Riddle llvmlistbot at llvm.org
Fri Apr 24 12:24:06 PDT 2020


Author: River Riddle
Date: 2020-04-24T12:23:32-07:00
New Revision: 0e2bd49370197dd8bf2c36ee0ce1275f7cfb515b

URL: https://github.com/llvm/llvm-project/commit/0e2bd49370197dd8bf2c36ee0ce1275f7cfb515b
DIFF: https://github.com/llvm/llvm-project/commit/0e2bd49370197dd8bf2c36ee0ce1275f7cfb515b.diff

LOG: [mlir][DictionaryAttr] Add a new getWithSorted and use it when possible

The elements of a DictionaryAttr are sorted by name. In many situations, e.g. NamedAttributeList, we can guarantee that the elements are sorted on construction and remove the need to perform extra checks. In places with lots of calls to attribute methods, this leads to a good performance improvement.

Differential Revision: https://reviews.llvm.org/D78781

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/IR/Attributes.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index e65f4a0b0624..656c28ba4e8a 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -278,9 +278,18 @@ class DictionaryAttr
   using Base::Base;
   using ValueType = ArrayRef<NamedAttribute>;
 
+  /// Construct a dictionary attribute with the provided list of named
+  /// attributes. This method assumes that the provided list is unordered. If
+  /// the caller can guarantee that the attributes are ordered by name,
+  /// getWithSorted should be used instead.
   static DictionaryAttr get(ArrayRef<NamedAttribute> value,
                             MLIRContext *context);
 
+  /// Construct a dictionary with an array of values that is known to already be
+  /// sorted by name and uniqued.
+  static DictionaryAttr getWithSorted(ArrayRef<NamedAttribute> value,
+                                      MLIRContext *context);
+
   ArrayRef<NamedAttribute> getValue() const;
 
   /// Return the specified attribute if present, null otherwise.
@@ -1455,8 +1464,8 @@ inline ::llvm::hash_code hash_value(Attribute arg) {
 // NamedAttributeList
 //===----------------------------------------------------------------------===//
 
-/// A NamedAttributeList is used to manage a list of named attributes. This
-/// provides simple interfaces for adding/removing/finding attributes from
+/// A NamedAttributeList is a mutable wrapper around a DictionaryAttr. It
+/// provides additional interfaces for adding, removing, replacing attributes
 /// within a DictionaryAttr.
 ///
 /// We assume there will be relatively few attributes on a given operation

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 26cb0513c099..a380bc7b22c5 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -90,6 +90,17 @@ static int compareNamedAttributes(const NamedAttribute *lhs,
   return strcmp(lhs->first.data(), rhs->first.data());
 }
 
+/// Returns if the name of the given attribute precedes that of 'name'.
+static bool compareNamedAttributeWithName(const NamedAttribute &attr,
+                                          StringRef name) {
+  // This is correct even when attr.first.data()[name.size()] is not a zero
+  // string terminator, because we only care about a less than comparison.
+  // This can't use memcmp, because it doesn't guarantee that it will stop
+  // reading both buffers if one is shorter than the other, even if there is
+  // a difference.
+  return strncmp(attr.first.data(), name.data(), name.size()) < 0;
+}
+
 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
                                    MLIRContext *context) {
   assert(llvm::all_of(value,
@@ -145,6 +156,24 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
   return Base::get(context, StandardAttributes::Dictionary, value);
 }
 
+/// Construct a dictionary with an array of values that is known to already be
+/// sorted by name and uniqued.
+DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
+                                             MLIRContext *context) {
+  // Ensure that the attribute elements are unique and sorted.
+  assert(llvm::is_sorted(value,
+                         [](NamedAttribute l, NamedAttribute r) {
+                           return l.first.strref() < r.first.strref();
+                         }) &&
+         "expected attribute values to be sorted");
+  assert(std::adjacent_find(value.begin(), value.end(),
+                            [](NamedAttribute l, NamedAttribute r) {
+                              return l.first == r.first;
+                            }) == value.end() &&
+         "DictionaryAttr element names must be unique");
+  return Base::get(context, StandardAttributes::Dictionary, value);
+}
+
 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
   return getImpl()->getElements();
 }
@@ -152,15 +181,7 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
 /// Return the specified attribute if present, null otherwise.
 Attribute DictionaryAttr::get(StringRef name) const {
   ArrayRef<NamedAttribute> values = getValue();
-  auto compare = [](NamedAttribute attr, StringRef name) -> bool {
-    // This is correct even when attr.first.data()[name.size()] is not a zero
-    // string terminator, because we only care about a less than comparison.
-    // This can't use memcmp, because it doesn't guarantee that it will stop
-    // reading both buffers if one is shorter than the other, even if there is
-    // a difference.
-    return strncmp(attr.first.data(), name.data(), name.size()) < 0;
-  };
-  auto it = llvm::lower_bound(values, name, compare);
+  auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
   return it != values.end() && it->first == name ? it->second : Attribute();
 }
 Attribute DictionaryAttr::get(Identifier name) const {
@@ -1158,19 +1179,29 @@ Attribute NamedAttributeList::get(Identifier name) const {
 void NamedAttributeList::set(Identifier name, Attribute value) {
   assert(value && "attributes may never be null");
 
-  // If we already have this attribute, replace it.
-  auto origAttrs = getAttrs();
-  SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
-  for (auto &elt : newAttrs)
-    if (elt.first == name) {
-      elt.second = value;
-      attrs = DictionaryAttr::get(newAttrs, value.getContext());
+  // Look for an existing value for the given name, and set it in-place.
+  ArrayRef<NamedAttribute> values = getAttrs();
+  auto it = llvm::find_if(
+      values, [name](NamedAttribute attr) { return attr.first == name; });
+  if (it != values.end()) {
+    // Bail out early if the value is the same as what we already have.
+    if (it->second == value)
       return;
-    }
 
-  // Otherwise, add it.
+    SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
+    newAttrs[it - values.begin()].second = value;
+    attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
+    return;
+  }
+
+  // Otherwise, insert the new attribute into its sorted position.
+  it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
+  SmallVector<NamedAttribute, 8> newAttrs;
+  newAttrs.reserve(values.size() + 1);
+  newAttrs.append(values.begin(), it);
   newAttrs.push_back({name, value});
-  attrs = DictionaryAttr::get(newAttrs, value.getContext());
+  newAttrs.append(it, values.end());
+  attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
 }
 
 /// Remove the attribute with the specified name if it exists.  The return
@@ -1189,7 +1220,8 @@ auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
       newAttrs.reserve(origAttrs.size() - 1);
       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
-      attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
+      attrs = DictionaryAttr::getWithSorted(newAttrs,
+                                            newAttrs[0].second.getContext());
       return RemoveResult::Removed;
     }
   }


        


More information about the Mlir-commits mailing list