[Mlir-commits] [mlir] 10a80c4 - [mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElementAttr interface
Markus Böck
llvmlistbot at llvm.org
Thu Oct 28 10:08:26 PDT 2021
Author: Markus Böck
Date: 2021-10-28T19:08:20+02:00
New Revision: 10a80c44133223c99ea67c805835e24938d840ea
URL: https://github.com/llvm/llvm-project/commit/10a80c44133223c99ea67c805835e24938d840ea
DIFF: https://github.com/llvm/llvm-project/commit/10a80c44133223c99ea67c805835e24938d840ea.diff
LOG: [mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElementAttr interface
This patch extends the SubElementAttr interface to allow replacing a contained sub attribute. The attribute that should be replaced is identified by an index which denotes the n-th element returned by the accompanying walkImmediateSubElements method.
Using this addition, the patch implements replacement of SymbolRefAttrs contained within any dialect attribute.
Differential Revision: https://reviews.llvm.org/D111357
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/test/IR/test-symbol-rauw.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index fcd6082a6ef8..51ac32d9a564 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -71,7 +71,8 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
//===----------------------------------------------------------------------===//
def Builtin_ArrayAttr : Builtin_Attr<"Array", [
- DeclareAttrInterfaceMethods<SubElementAttrInterface>
+ DeclareAttrInterfaceMethods<SubElementAttrInterface,
+ ["replaceImmediateSubAttribute"]>
]> {
let summary = "A collection of other Attribute values";
let description = [{
@@ -345,7 +346,8 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
- DeclareAttrInterfaceMethods<SubElementAttrInterface>
+ DeclareAttrInterfaceMethods<SubElementAttrInterface,
+ ["replaceImmediateSubAttribute"]>
]> {
let summary = "An dictionary of named Attribute values";
let description = [{
@@ -954,10 +956,11 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
symbol nested within a different symbol table.
This attribute can only be held internally by
- [array attributes](#array-attribute) and
+ [array attributes](#array-attribute),
[dictionary attributes](#dictionary-attribute)(including the top-level
- operation attribute dictionary), i.e. no other attribute kinds such as
- Locations or extended attribute kinds.
+ operation attribute dictionary) as well as attributes exposing it via
+ the `SubElementAttrInterface` interface. Symbol reference attributes
+ nested in types are currently not supported.
**Rationale:** Identifying accesses to global data is critical to
enabling efficient multi-threaded compilation. Restricting global
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index 8a4885237865..9ee90513cdda 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -33,6 +33,20 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
(ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
"llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Replace the attributes identified by the indices with the corresponding
+ value. The index is derived from the order of the attributes returned by
+ the attribute callback of `walkImmediateSubElements`. An index of 0 would
+ replace the very first attribute given by `walkImmediateSubElements`.
+ The new instance with the values replaced is returned.
+ }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
+ (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
+ [{}],
+ /*defaultImplementation=*/[{
+ llvm_unreachable("Attribute or Type does not support replacing attributes");
+ }]
+ >,
];
code extraClassDeclaration = [{
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fe8f6a54d82f..72891d995af6 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -53,6 +53,15 @@ void ArrayAttr::walkImmediateSubElements(
walkAttrsFn(attr);
}
+SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
+ ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+ std::vector<Attribute> vector = getValue().vec();
+ for (auto &it : replacements) {
+ vector[it.first] = it.second;
+ }
+ return get(getContext(), vector);
+}
+
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
@@ -217,6 +226,17 @@ void DictionaryAttr::walkImmediateSubElements(
walkAttrsFn(attr);
}
+SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
+ ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+ std::vector<NamedAttribute> vec = getValue().vec();
+ for (auto &it : replacements) {
+ vec[it.first].second = it.second;
+ }
+ // The above only modifies the mapped value, but not the key, and therefore
+ // not the order of the elements. It remains sorted
+ return getWithSorted(getContext(), vec);
+}
+
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index ad4f08364e7b..6634eab4150e 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -485,16 +485,30 @@ static WalkResult walkSymbolRefs(
// A worklist of a container attribute and the current index into the held
// attribute list.
- SmallVector<Attribute, 1> attrWorklist(1, attrDict);
+ struct WorklistItem {
+ SubElementAttrInterface container;
+ SmallVector<Attribute> immediateSubElements;
+
+ explicit WorklistItem(SubElementAttrInterface container) {
+ SmallVector<Attribute> subElements;
+ container.walkImmediateSubElements(
+ [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
+ immediateSubElements = std::move(subElements);
+ }
+ };
+
+ SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
// Process the symbol references within the given nested attribute range.
- auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
- for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+ auto processAttrs = [&](int &index,
+ WorklistItem &worklistItem) -> WalkResult {
+ for (Attribute attr :
+ llvm::drop_begin(worklistItem.immediateSubElements, index)) {
/// Check for a nested container attribute, these will also need to be
/// walked.
- if (attr.isa<ArrayAttr, DictionaryAttr>()) {
- attrWorklist.push_back(attr);
+ if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+ attrWorklist.emplace_back(interface);
curAccessChain.push_back(-1);
return WalkResult::advance();
}
@@ -517,15 +531,12 @@ static WalkResult walkSymbolRefs(
WalkResult result = WalkResult::advance();
do {
- Attribute attr = attrWorklist.back();
+ WorklistItem &item = attrWorklist.back();
int &index = curAccessChain.back();
++index;
// Process the given attribute, which is guaranteed to be a container.
- if (auto dict = attr.dyn_cast<DictionaryAttr>())
- result = processAttrs(index, make_second_range(dict.getValue()));
- else
- result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+ result = processAttrs(index, item);
} while (!attrWorklist.empty() && !result.wasInterrupted());
return result;
}
@@ -811,48 +822,46 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
/// Rebuild the given attribute container after replacing all references to a
/// symbol with the updated attribute in 'accesses'.
-static Attribute rebuildAttrAfterRAUW(
- Attribute container,
+static SubElementAttrInterface rebuildAttrAfterRAUW(
+ SubElementAttrInterface container,
ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
unsigned depth) {
// Given a range of Attributes, update the ones referred to by the given
// access chains to point to the new symbol attribute.
- auto updateAttrs = [&](auto &&attrRange) {
- auto attrBegin = std::begin(attrRange);
- for (unsigned i = 0, e = accesses.size(); i != e;) {
- ArrayRef<int> access = accesses[i].first;
- Attribute &attr = *std::next(attrBegin, access[depth]);
-
- // Check to see if this is a leaf access, i.e. a SymbolRef.
- if (access.size() == depth + 1) {
- attr = accesses[i].second;
- ++i;
- continue;
- }
- // Otherwise, this is a container. Collect all of the accesses for this
- // index and recurse. The recursion here is bounded by the size of the
- // largest access array.
- auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
- ArrayRef<int> nextAccess = it.first;
- return nextAccess.size() > depth + 1 &&
- nextAccess[depth] == access[depth];
- });
- attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
-
- // Skip over all of the accesses that refer to the nested container.
- i += nestedAccesses.size();
+ SmallVector<std::pair<size_t, Attribute>> replacements;
+
+ SmallVector<Attribute> subElements;
+ container.walkImmediateSubElements(
+ [&](Attribute attribute) { subElements.push_back(attribute); },
+ [](Type) {});
+ for (unsigned i = 0, e = accesses.size(); i != e;) {
+ ArrayRef<int> access = accesses[i].first;
+
+ // Check to see if this is a leaf access, i.e. a SymbolRef.
+ if (access.size() == depth + 1) {
+ replacements.emplace_back(access.back(), accesses[i].second);
+ ++i;
+ continue;
}
- };
- if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
- auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
- updateAttrs(make_second_range(newAttrs));
- return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+ // Otherwise, this is a container. Collect all of the accesses for this
+ // index and recurse. The recursion here is bounded by the size of the
+ // largest access array.
+ auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+ ArrayRef<int> nextAccess = it.first;
+ return nextAccess.size() > depth + 1 &&
+ nextAccess[depth] == access[depth];
+ });
+ auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
+ nestedAccesses, depth + 1);
+ replacements.emplace_back(access[depth], result);
+
+ // Skip over all of the accesses that refer to the nested container.
+ i += nestedAccesses.size();
}
- auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
- updateAttrs(newAttrs);
- return ArrayAttr::get(container.getContext(), newAttrs);
+
+ return container.replaceImmediateSubAttribute(replacements);
}
/// Generates a new symbol reference attribute with a new leaf reference.
diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index 5d50bc0483b9..931c26b6b344 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -73,3 +73,24 @@ module {
"foo.possibly_unknown_symbol_table"() ({
}) : () -> ()
}
+
+// -----
+
+// Check that replacement works in any implementation of SubElementAttrInterface
+module {
+ // CHECK: func private @replaced_foo
+ func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
+
+ // CHECK: func @symbol_bar
+ func @symbol_bar() {
+ // CHECK: foo.op
+ // CHECK-SAME: non_symbol_attr,
+ // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
+ // CHECK-SAME: z_non_symbol_attr_3
+ "foo.op"() {
+ non_symbol_attr,
+ use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
+ z_non_symbol_attr_3
+ } : () -> ()
+ }
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8e36f634fdc9..3062fd6c65ca 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
// To get the test dialect definition.
include "TestOps.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@@ -101,4 +102,18 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
let genVerifyDecl = 1;
}
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
+ DeclareAttrInterfaceMethods<SubElementAttrInterface,
+ ["replaceImmediateSubAttribute"]>
+ ]> {
+
+ let mnemonic = "sub_elements_access";
+
+ let parameters = (ins
+ "::mlir::Attribute":$first,
+ "::mlir::Attribute":$second,
+ "::mlir::Attribute":$third
+ );
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 29b702323856..9cd9c574a7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -127,6 +127,57 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestSubElementsAccessAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
+ ::mlir::Type type) {
+ Attribute first, second, third;
+ if (parser.parseLess() || parser.parseAttribute(first) ||
+ parser.parseComma() || parser.parseAttribute(second) ||
+ parser.parseComma() || parser.parseAttribute(third) ||
+ parser.parseGreater()) {
+ return {};
+ }
+ return get(parser.getContext(), first, second, third);
+}
+
+void TestSubElementsAccessAttr::print(
+ ::mlir::DialectAsmPrinter &printer) const {
+ printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
+ << getThird() << ">";
+}
+
+void TestSubElementsAccessAttr::walkImmediateSubElements(
+ llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+ llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
+ walkAttrsFn(getFirst());
+ walkAttrsFn(getSecond());
+ walkAttrsFn(getThird());
+}
+
+SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
+ ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+ Attribute first = getFirst();
+ Attribute second = getSecond();
+ Attribute third = getThird();
+ for (auto &it : replacements) {
+ switch (it.first) {
+ case 0:
+ first = it.second;
+ break;
+ case 1:
+ second = it.second;
+ break;
+ case 2:
+ third = it.second;
+ break;
+ }
+ }
+ return get(getContext(), first, second, third);
+}
+
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list