[Mlir-commits] [mlir] 00a52c7 - [mlir:SubElementsInterface] Add support for "skipping" when replacing attributes/types
River Riddle
llvmlistbot at llvm.org
Thu Jul 28 10:52:39 PDT 2022
Author: River Riddle
Date: 2022-07-28T10:52:12-07:00
New Revision: 00a52c75655bb352f875729a93c3f2ae990e5b78
URL: https://github.com/llvm/llvm-project/commit/00a52c75655bb352f875729a93c3f2ae990e5b78
DIFF: https://github.com/llvm/llvm-project/commit/00a52c75655bb352f875729a93c3f2ae990e5b78.diff
LOG: [mlir:SubElementsInterface] Add support for "skipping" when replacing attributes/types
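This is used to fix a bug in SymbolTable::replaceAllSymbolUses where, after a
matching SymbolRefAttr was replaced, its nested references were still walked and
could be replaced as well. For example, when `@foo` is renamed to `@replaced_foo`,
a use of `@foo::@foo` must become `@replaced_foo::@foo`, leaving the inner `@foo`
untouched.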
Differential Revision: https://reviews.llvm.org/D130693
Added:
Modified:
mlir/include/mlir/IR/SubElementInterfaces.h
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/SubElementInterfaces.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/test/IR/test-symbol-rauw.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index d0fc0ba921708..2c40e4edfa0fa 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -16,6 +16,14 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
+#include "mlir/IR/Visitors.h"
+
+namespace mlir {
+template <typename T>
+using SubElementReplFn = function_ref<T(T)>;
+template <typename T>
+using SubElementResultReplFn = function_ref<std::pair<T, WalkResult>(T)>;
+} // namespace mlir
/// Include the definitions of the sub elemnt interfaces.
#include "mlir/IR/SubElementAttrInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index f1aede639cf09..e857beb379de3 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -56,8 +56,22 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
/// Recursively replace all of the nested sub-attributes and sub-types using the
/// provided map functions. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements(
- llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
- llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn
+ mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
+ mlir::SubElementReplFn<mlir::Type> replaceTypeFn
+ ) {
+ return replaceSubElements(
+ [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); },
+ [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); }
+ );
+ }
+ /// Recursively replace all of the nested sub-attributes and sub-types using the
+ /// provided map functions. This variant allows for the map function to return an
+ /// additional walk result. Returns nullptr in the case of failure.
+ }] # attrOrType # [{ replaceSubElements(
+ llvm::function_ref<
+ std::pair<mlir::Attribute, mlir::WalkResult>(mlir::Attribute)> replaceAttrFn,
+ llvm::function_ref<
+ std::pair<mlir::Type, mlir::WalkResult>(mlir::Type)> replaceTypeFn
);
}];
code extraTraitClassDeclaration = [{
@@ -71,18 +85,16 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
/// Recursively replace all of the nested sub-attributes and sub-types using the
/// provided map functions. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements(
- llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
- llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+ mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
+ mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
}] # interfaceName # " interface(" # derivedValue # [{);
return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
}
-
- /// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceImmediateSubElements(
- llvm::ArrayRef<mlir::Attribute> replAttrs,
- llvm::function_ref<mlir::Type(mlir::Type)> replTypes) {
- return nullptr;
+ }] # attrOrType # [{ replaceSubElements(
+ mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn,
+ mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
+ }] # interfaceName # " interface(" # derivedValue # [{);
+ return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
}
}];
code extraSharedClassDeclaration = [{
@@ -98,17 +110,31 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
/// Recursively replace all of the nested sub-attributes using the provided
/// map function. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements(
- llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn) {
+ mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn) {
return replaceSubElements(
replaceAttrFn, [](mlir::Type type) { return type; });
}
+ }] # attrOrType # [{ replaceSubElements(
+ mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn) {
+ return replaceSubElements(
+ replaceAttrFn,
+ [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); }
+ );
+ }
/// Recursively replace all of the nested sub-types using the provided map
/// function. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements(
- llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+ mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
return replaceSubElements(
[](mlir::Attribute attr) { return attr; }, replaceTypeFn);
}
+ }] # attrOrType # [{ replaceSubElements(
+ mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
+ return replaceSubElements(
+ [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); },
+ replaceTypeFn
+ );
+ }
}];
}
diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index f8d47083f11c4..f8526dc6d3869 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -102,11 +102,10 @@ static bool isMutable(Type type) {
}
template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
-static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
- DenseMap<T, T> &visited,
- SmallVectorImpl<T> &newElements,
- FailureOr<bool> &changed,
- ReplaceSubElementFnT &&replaceSubElementFn) {
+static void updateSubElementImpl(
+ T element, function_ref<std::pair<T, WalkResult>(T)> walkFn,
+ DenseMap<T, T> &visited, SmallVectorImpl<T> &newElements,
+ FailureOr<bool> &changed, ReplaceSubElementFnT &&replaceSubElementFn) {
// Bail early if we failed at any point.
if (failed(changed))
return;
@@ -120,17 +119,22 @@ static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
// yet.
T &mappedElement = visited[element];
if (!mappedElement) {
+ WalkResult result = WalkResult::advance();
+ std::tie(mappedElement, result) = walkFn(element);
+
// Try walking this element.
- if (!(mappedElement = walkFn(element))) {
+ if (result.wasInterrupted() || !mappedElement) {
changed = failure();
return;
}
// Handle replacing sub-elements if this element is also a container.
- if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
- if (!(mappedElement = replaceSubElementFn(interface))) {
- changed = failure();
- return;
+ if (!result.wasSkipped()) {
+ if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
+ if (!(mappedElement = replaceSubElementFn(interface))) {
+ changed = failure();
+ return;
+ }
}
}
}
@@ -145,8 +149,8 @@ static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
template <typename InterfaceT>
static typename InterfaceT::ValueType
replaceSubElementsImpl(InterfaceT interface,
- function_ref<Attribute(Attribute)> walkAttrsFn,
- function_ref<Type(Type)> walkTypesFn,
+ SubElementResultReplFn<Attribute> walkAttrsFn,
+ SubElementResultReplFn<Type> walkTypesFn,
DenseMap<Attribute, Attribute> &visitedAttrs,
DenseMap<Type, Type> &visitedTypes) {
// Walk the current sub-elements, replacing them as necessary.
@@ -186,8 +190,8 @@ replaceSubElementsImpl(InterfaceT interface,
}
Attribute SubElementAttrInterface::replaceSubElements(
- function_ref<Attribute(Attribute)> replaceAttrFn,
- function_ref<Type(Type)> replaceTypeFn) {
+ SubElementResultReplFn<Attribute> replaceAttrFn,
+ SubElementResultReplFn<Type> replaceTypeFn) {
assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
DenseMap<Attribute, Attribute> visitedAttrs;
DenseMap<Type, Type> visitedTypes;
@@ -196,8 +200,8 @@ Attribute SubElementAttrInterface::replaceSubElements(
}
Type SubElementTypeInterface::replaceSubElements(
- function_ref<Attribute(Attribute)> replaceAttrFn,
- function_ref<Type(Type)> replaceTypeFn) {
+ SubElementResultReplFn<Attribute> replaceAttrFn,
+ SubElementResultReplFn<Type> replaceTypeFn) {
assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
DenseMap<Attribute, Attribute> visitedAttrs;
DenseMap<Type, Type> visitedTypes;
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index fb56d91f68a6c..792bf42488372 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -853,23 +853,30 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
- auto remapAttrFn = [&](Attribute attr) -> Attribute {
+ auto remapAttrFn =
+ [&](Attribute attr) -> std::pair<Attribute, WalkResult> {
+ // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
+ // want to accidentally replace an inner reference.
if (attr == oldAttr)
- return newAttr;
+ return {newAttr, WalkResult::skip()};
// Handle prefix matches.
if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
if (isReferencePrefixOf(oldAttr, symRef)) {
auto oldNestedRefs = oldAttr.getNestedReferences();
auto nestedRefs = symRef.getNestedReferences();
if (oldNestedRefs.empty())
- return SymbolRefAttr::get(newSymbol, nestedRefs);
+ return {SymbolRefAttr::get(newSymbol, nestedRefs),
+ WalkResult::skip()};
auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
- return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
+ return {
+ SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs),
+ WalkResult::skip()};
}
+ return {attr, WalkResult::skip()};
}
- return attr;
+ return {attr, WalkResult::advance()};
};
// Generate a new attribute dictionary by replacing references to the old
// symbol.
diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index b33651f1129d6..c7d48b6c4eb1d 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -94,3 +94,19 @@ module {
} : () -> ()
}
}
+
+// -----
+
+module {
+ // CHECK: module @replaced_foo
+ module @foo attributes {sym.new_name = "replaced_foo" } {
+ // CHECK: func.func private @foo
+ func.func private @foo()
+ }
+
+ // CHECK: foo.op
+ // CHECK-SAME: use = @replaced_foo::@foo
+ "foo.op"() {
+ use = @foo::@foo
+ } : () -> ()
+}
More information about the Mlir-commits
mailing list