[Mlir-commits] [mlir] 10a80c4 - [mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElementAttr interface

Markus Böck llvmlistbot at llvm.org
Thu Oct 28 10:08:26 PDT 2021


Author: Markus Böck
Date: 2021-10-28T19:08:20+02:00
New Revision: 10a80c44133223c99ea67c805835e24938d840ea

URL: https://github.com/llvm/llvm-project/commit/10a80c44133223c99ea67c805835e24938d840ea
DIFF: https://github.com/llvm/llvm-project/commit/10a80c44133223c99ea67c805835e24938d840ea.diff

LOG: [mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElementAttr interface

This patch extends the SubElementAttr interface to allow replacing a contained sub attribute. The attribute that should be replaced is identified by an index which denotes the n-th element returned by the accompanying walkImmediateSubElements method.

Using this addition the patch implements replacing SymbolRefAttrs contained within any dialect attributes.

Differential Revision: https://reviews.llvm.org/D111357

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/IR/test-symbol-rauw.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index fcd6082a6ef8..51ac32d9a564 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -71,7 +71,8 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -345,7 +346,8 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
@@ -954,10 +956,11 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
     symbol nested within a different symbol table.
 
     This attribute can only be held internally by
-    [array attributes](#array-attribute) and
+    [array attributes](#array-attribute),
     [dictionary attributes](#dictionary-attribute)(including the top-level
-    operation attribute dictionary), i.e. no other attribute kinds such as
-    Locations or extended attribute kinds.
+    operation attribute dictionary) as well as attributes exposing it via
+    the `SubElementAttrInterface` interface. Symbol reference attributes
+    nested in types are currently not supported.
 
     **Rationale:** Identifying accesses to global data is critical to
     enabling efficient multi-threaded compilation. Restricting global

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index 8a4885237865..9ee90513cdda 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -33,6 +33,20 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
       (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
            "llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Replace the attributes identified by the indices with the corresponding
+        value. The index is derived from the order of the attributes returned by
+        the attribute callback of `walkImmediateSubElements`. An index of 0 would
+        replace the very first attribute given by `walkImmediateSubElements`.
+        The new instance with the values replaced is returned.
+      }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
+      (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
+      [{}],
+      /*defaultImplementation=*/[{
+        llvm_unreachable("Attribute or Type does not support replacing attributes");
+      }]
+    >,
   ];
 
   code extraClassDeclaration = [{

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fe8f6a54d82f..72891d995af6 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -53,6 +53,15 @@ void ArrayAttr::walkImmediateSubElements(
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<Attribute> vector = getValue().vec();
+  for (auto &it : replacements) {
+    vector[it.first] = it.second;
+  }
+  return get(getContext(), vector);
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -217,6 +226,17 @@ void DictionaryAttr::walkImmediateSubElements(
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<NamedAttribute> vec = getValue().vec();
+  for (auto &it : replacements) {
+    vec[it.first].second = it.second;
+  }
+  // The above only modifies the mapped value, but not the key, and therefore
+  // not the order of the elements. It remains sorted
+  return getWithSorted(getContext(), vec);
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index ad4f08364e7b..6634eab4150e 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -485,16 +485,30 @@ static WalkResult walkSymbolRefs(
 
   // A worklist of a container attribute and the current index into the held
   // attribute list.
-  SmallVector<Attribute, 1> attrWorklist(1, attrDict);
+  struct WorklistItem {
+    SubElementAttrInterface container;
+    SmallVector<Attribute> immediateSubElements;
+
+    explicit WorklistItem(SubElementAttrInterface container) {
+      SmallVector<Attribute> subElements;
+      container.walkImmediateSubElements(
+          [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
+      immediateSubElements = std::move(subElements);
+    }
+  };
+
+  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
   SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
 
   // Process the symbol references within the given nested attribute range.
-  auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
-    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+  auto processAttrs = [&](int &index,
+                          WorklistItem &worklistItem) -> WalkResult {
+    for (Attribute attr :
+         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
       /// Check for a nested container attribute, these will also need to be
       /// walked.
-      if (attr.isa<ArrayAttr, DictionaryAttr>()) {
-        attrWorklist.push_back(attr);
+      if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+        attrWorklist.emplace_back(interface);
         curAccessChain.push_back(-1);
         return WalkResult::advance();
       }
@@ -517,15 +531,12 @@ static WalkResult walkSymbolRefs(
 
   WalkResult result = WalkResult::advance();
   do {
-    Attribute attr = attrWorklist.back();
+    WorklistItem &item = attrWorklist.back();
     int &index = curAccessChain.back();
     ++index;
 
     // Process the given attribute, which is guaranteed to be a container.
-    if (auto dict = attr.dyn_cast<DictionaryAttr>())
-      result = processAttrs(index, make_second_range(dict.getValue()));
-    else
-      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+    result = processAttrs(index, item);
   } while (!attrWorklist.empty() && !result.wasInterrupted());
   return result;
 }
@@ -811,48 +822,46 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
 
 /// Rebuild the given attribute container after replacing all references to a
 /// symbol with the updated attribute in 'accesses'.
-static Attribute rebuildAttrAfterRAUW(
-    Attribute container,
+static SubElementAttrInterface rebuildAttrAfterRAUW(
+    SubElementAttrInterface container,
     ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
     unsigned depth) {
   // Given a range of Attributes, update the ones referred to by the given
   // access chains to point to the new symbol attribute.
-  auto updateAttrs = [&](auto &&attrRange) {
-    auto attrBegin = std::begin(attrRange);
-    for (unsigned i = 0, e = accesses.size(); i != e;) {
-      ArrayRef<int> access = accesses[i].first;
-      Attribute &attr = *std::next(attrBegin, access[depth]);
-
-      // Check to see if this is a leaf access, i.e. a SymbolRef.
-      if (access.size() == depth + 1) {
-        attr = accesses[i].second;
-        ++i;
-        continue;
-      }
 
-      // Otherwise, this is a container. Collect all of the accesses for this
-      // index and recurse. The recursion here is bounded by the size of the
-      // largest access array.
-      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
-        ArrayRef<int> nextAccess = it.first;
-        return nextAccess.size() > depth + 1 &&
-               nextAccess[depth] == access[depth];
-      });
-      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
-
-      // Skip over all of the accesses that refer to the nested container.
-      i += nestedAccesses.size();
+  SmallVector<std::pair<size_t, Attribute>> replacements;
+
+  SmallVector<Attribute> subElements;
+  container.walkImmediateSubElements(
+      [&](Attribute attribute) { subElements.push_back(attribute); },
+      [](Type) {});
+  for (unsigned i = 0, e = accesses.size(); i != e;) {
+    ArrayRef<int> access = accesses[i].first;
+
+    // Check to see if this is a leaf access, i.e. a SymbolRef.
+    if (access.size() == depth + 1) {
+      replacements.emplace_back(access.back(), accesses[i].second);
+      ++i;
+      continue;
     }
-  };
 
-  if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
-    auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
-    updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+    // Otherwise, this is a container. Collect all of the accesses for this
+    // index and recurse. The recursion here is bounded by the size of the
+    // largest access array.
+    auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+      ArrayRef<int> nextAccess = it.first;
+      return nextAccess.size() > depth + 1 &&
+             nextAccess[depth] == access[depth];
+    });
+    auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
+                                       nestedAccesses, depth + 1);
+    replacements.emplace_back(access[depth], result);
+
+    // Skip over all of the accesses that refer to the nested container.
+    i += nestedAccesses.size();
   }
-  auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
-  updateAttrs(newAttrs);
-  return ArrayAttr::get(container.getContext(), newAttrs);
+
+  return container.replaceImmediateSubAttribute(replacements);
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.

diff  --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index 5d50bc0483b9..931c26b6b344 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -73,3 +73,24 @@ module {
   "foo.possibly_unknown_symbol_table"() ({
   }) : () -> ()
 }
+
+// -----
+
+// Check that replacement works in any implementations of SubElementsAttrInterface
+module {
+    // CHECK: func private @replaced_foo
+    func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
+
+    // CHECK: func @symbol_bar
+    func @symbol_bar() {
+      // CHECK: foo.op
+      // CHECK-SAME: non_symbol_attr,
+      // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
+      // CHECK-SAME: z_non_symbol_attr_3
+      "foo.op"() {
+        non_symbol_attr,
+        use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
+        z_non_symbol_attr_3
+      } : () -> ()
+    }
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8e36f634fdc9..3062fd6c65ca 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestOps.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -101,4 +102,18 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
   let genVerifyDecl = 1;
 }
 
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
+  ]> {
+
+  let mnemonic = "sub_elements_access";
+
+  let parameters = (ins
+    "::mlir::Attribute":$first,
+    "::mlir::Attribute":$second,
+    "::mlir::Attribute":$third
+  );
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 29b702323856..9cd9c574a7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -127,6 +127,57 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestSubElementsAccessAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
+                                           ::mlir::Type type) {
+  Attribute first, second, third;
+  if (parser.parseLess() || parser.parseAttribute(first) ||
+      parser.parseComma() || parser.parseAttribute(second) ||
+      parser.parseComma() || parser.parseAttribute(third) ||
+      parser.parseGreater()) {
+    return {};
+  }
+  return get(parser.getContext(), first, second, third);
+}
+
+void TestSubElementsAccessAttr::print(
+    ::mlir::DialectAsmPrinter &printer) const {
+  printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
+          << getThird() << ">";
+}
+
+void TestSubElementsAccessAttr::walkImmediateSubElements(
+    llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+    llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
+  walkAttrsFn(getFirst());
+  walkAttrsFn(getSecond());
+  walkAttrsFn(getThird());
+}
+
+SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  Attribute first = getFirst();
+  Attribute second = getSecond();
+  Attribute third = getThird();
+  for (auto &it : replacements) {
+    switch (it.first) {
+    case 0:
+      first = it.second;
+      break;
+    case 1:
+      second = it.second;
+      break;
+    case 2:
+      third = it.second;
+      break;
+    }
+  }
+  return get(getContext(), first, second, third);
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list