[Mlir-commits] [mlir] e50941b - [mlir] Add a new AttrTypeReplacer class to simplify sub element replacements
River Riddle
llvmlistbot at llvm.org
Sat Nov 12 15:06:00 PST 2022
Author: River Riddle
Date: 2022-11-12T14:38:45-08:00
New Revision: e50941b8d7bad9af091286afe45403b554824ce3
URL: https://github.com/llvm/llvm-project/commit/e50941b8d7bad9af091286afe45403b554824ce3
DIFF: https://github.com/llvm/llvm-project/commit/e50941b8d7bad9af091286afe45403b554824ce3.diff
LOG: [mlir] Add a new AttrTypeReplacer class to simplify sub element replacements
We currently only have the SubElement interface API for attribute/type
replacement, but this suffers from several issues; namely that it doesn't
allow caching across multiple replacements (very common), and also
creates a somewhat awkward/limited API. The new AttrTypeReplacer class
allows for registering replacements using a much cleaner API, similarly to
the TypeConverter class, removes a lot of manual interaction with the
sub element interfaces, and also better enables large scale replacements.
Differential Revision: https://reviews.llvm.org/D137764
Added:
Modified:
mlir/include/mlir/IR/SubElementInterfaces.h
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/SubElementInterfaces.cpp
mlir/lib/IR/SymbolTable.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index 2af7642e93b25..016269282c22e 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -19,10 +19,114 @@
#include "mlir/IR/Visitors.h"
namespace mlir {
-template <typename T>
-using SubElementReplFn = function_ref<T(T)>;
-template <typename T>
-using SubElementResultReplFn = function_ref<std::pair<T, WalkResult>(T)>;
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+/// This class provides a utility for replacing attributes/types, and their sub
+/// elements. Multiple replacement functions may be registered.
+class AttrTypeReplacer {
+public:
+ //===--------------------------------------------------------------------===//
+ // Application
+ //===--------------------------------------------------------------------===//
+
+ /// Replace the elements within the given operation. By default this includes
+ /// the attributes within the operation. If `replaceLocs` is true, this also
+ /// updates its location and the locations of any nested block arguments. If
+ /// `replaceTypes` is true, this also updates the result types of the
+ /// operation, and the types of any nested block arguments.
+ void replaceElementsIn(Operation *op, bool replaceLocs = false,
+ bool replaceTypes = false);
+
+ /// Replace the given attribute/type, and recursively replace any sub
+ /// elements. Returns either the new attribute/type, or nullptr in the case of
+ /// failure.
+ Attribute replace(Attribute attr);
+ Type replace(Type type);
+
+ //===--------------------------------------------------------------------===//
+ // Registration
+ //===--------------------------------------------------------------------===//
+
+ /// A replacement mapping function, which returns either None (to signal the
+ /// element wasn't handled), or a pair of the replacement element and a
+ /// WalkResult.
+ template <typename T>
+ using ReplaceFnResult = Optional<std::pair<T, WalkResult>>;
+ template <typename T>
+ using ReplaceFn = std::function<ReplaceFnResult<T>(T)>;
+
+ /// Register a replacement function for mapping a given attribute or type. A
+ /// replacement function must be convertible to any of the following
+ /// forms (where `T` is a class derived from `Type` or `Attribute`, and `BaseT`
+ /// is either `Type` or `Attribute` respectively):
+ ///
+ /// * Optional<BaseT>(T)
+ /// - This either returns a valid Attribute/Type in the case of success,
+ /// nullptr in the case of failure, or `llvm::None` to signify that
+ /// additional replacement functions may be applied (i.e. this function
+ /// doesn't handle that instance).
+ ///
+ /// * Optional<std::pair<BaseT, WalkResult>>(T)
+ /// - Similar to the above, but also allows specifying a WalkResult to
+ /// control the replacement of sub elements of a given attribute or
+ /// type. Returning a `skip` result, for example, will not recursively
+ /// process the resultant attribute or type value.
+ ///
+ /// Note: When replacing, the most recently added replacement functions will
+ /// be invoked first.
+ void addReplacement(ReplaceFn<Attribute> fn) {
+ attrReplacementFns.emplace_back(std::move(fn));
+ }
+ void addReplacement(ReplaceFn<Type> fn) {
+ typeReplacementFns.push_back(std::move(fn));
+ }
+
+ /// Register a replacement function that doesn't match the default signature,
+ /// either because it uses a derived parameter type, or it uses a simplified
+ /// result type.
+ template <typename FnT,
+ typename T = typename llvm::function_traits<
+ std::decay_t<FnT>>::template arg_t<0>,
+ typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+ Attribute, Type>,
+ typename ResultT = std::invoke_result_t<FnT, T>>
+ std::enable_if_t<!std::is_same_v<T, BaseT> ||
+ !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
+ addReplacement(FnT &&callback) {
+ addReplacement([callback = std::forward<FnT>(callback)](
+ BaseT base) -> ReplaceFnResult<BaseT> {
+ if (auto derived = dyn_cast<T>(base)) {
+ if constexpr (std::is_convertible_v<ResultT, Optional<BaseT>>) {
+ Optional<BaseT> result = callback(derived);
+ return result ? std::make_pair(*result, WalkResult::advance())
+ : ReplaceFnResult<BaseT>();
+ } else {
+ return callback(derived);
+ }
+ }
+ return ReplaceFnResult<BaseT>();
+ });
+ }
+
+private:
+ /// Internal implementation of the `replace` methods above.
+ template <typename InterfaceT, typename ReplaceFns, typename T>
+ T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap<T, T> &map);
+
+ /// Replace the sub elements of the given interface.
+ template <typename InterfaceT, typename T = typename InterfaceT::ValueType>
+ T replaceSubElements(InterfaceT interface, DenseMap<T, T> &interfaceMap);
+
+ /// The set of replacement functions that map sub elements.
+ std::vector<ReplaceFn<Attribute>> attrReplacementFns;
+ std::vector<ReplaceFn<Type>> typeReplacementFns;
+
+ /// The set of cached mappings for attributes/types.
+ DenseMap<Attribute, Attribute> attrMap;
+ DenseMap<Type, Type> typeMap;
+};
//===----------------------------------------------------------------------===//
/// AttrTypeSubElementHandler
@@ -291,7 +395,7 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
} // namespace detail
} // namespace mlir
-/// Include the definitions of the sub elemnt interfaces.
+/// Include the definitions of the sub element interfaces.
#include "mlir/IR/SubElementAttrInterfaces.h.inc"
#include "mlir/IR/SubElementTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index abb5afcc93aa1..7718feb0a43c6 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -66,25 +66,14 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
llvm::function_ref<void(mlir::Type)> walkTypesFn);
/// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
- mlir::SubElementReplFn<mlir::Type> replaceTypeFn
- ) {
- return replaceSubElements(
- [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); },
- [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); }
- );
+ /// provided map functions. Returns nullptr in the case of failure. See
+ /// `AttrTypeReplacer` for information on the supported replacement function types.
+ template <typename... ReplacementFns>
+ }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
+ AttrTypeReplacer replacer;
+ (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
+ return replacer.replace(*this);
}
- /// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. This variant allows for the map function to return an
- /// additional walk result. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceSubElements(
- llvm::function_ref<
- std::pair<mlir::Attribute, mlir::WalkResult>(mlir::Attribute)> replaceAttrFn,
- llvm::function_ref<
- std::pair<mlir::Type, mlir::WalkResult>(mlir::Type)> replaceTypeFn
- );
}];
code extraTraitClassDeclaration = [{
/// Walk all of the held sub-attributes and sub-types.
@@ -95,18 +84,13 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
}
/// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
- mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
- }] # interfaceName # " interface(" # derivedValue # [{);
- return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
- }
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn,
- mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
- }] # interfaceName # " interface(" # derivedValue # [{);
- return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
+ /// provided map functions. Returns nullptr in the case of failure. See
+ /// `AttrTypeReplacer` for information on the supported replacement function types.
+ template <typename... ReplacementFns>
+ }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
+ AttrTypeReplacer replacer;
+ (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
+ return replacer.replace(}] # derivedValue # [{);
}
}];
code extraSharedClassDeclaration = [{
@@ -118,35 +102,6 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
}
-
- /// Recursively replace all of the nested sub-attributes using the provided
- /// map function. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn) {
- return replaceSubElements(
- replaceAttrFn, [](mlir::Type type) { return type; });
- }
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn) {
- return replaceSubElements(
- replaceAttrFn,
- [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); }
- );
- }
- /// Recursively replace all of the nested sub-types using the provided map
- /// function. Returns nullptr in the case of failure.
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
- return replaceSubElements(
- [](mlir::Attribute attr) { return attr; }, replaceTypeFn);
- }
- }] # attrOrType # [{ replaceSubElements(
- mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
- return replaceSubElements(
- [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); },
- replaceTypeFn
- );
- }
}];
}
diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index fd05b9d01eea4..88aeaf191d90a 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/SubElementInterfaces.h"
+#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseSet.h"
@@ -91,116 +92,146 @@ void SubElementTypeInterface::walkSubElements(
}
//===----------------------------------------------------------------------===//
-// ReplaceSubElements
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
-template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
-static void updateSubElementImpl(
- T element, function_ref<std::pair<T, WalkResult>(T)> walkFn,
- DenseMap<T, T> &visited, SmallVectorImpl<T> &newElements,
- FailureOr<bool> &changed, ReplaceSubElementFnT &&replaceSubElementFn) {
- // Bail early if we failed at any point.
- if (failed(changed))
- return;
- newElements.push_back(element);
+void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceLocs,
+ bool replaceTypes) {
+ // Functor that replaces the given element if the new value is different,
+ // otherwise returns nullptr.
+ auto replaceIfDifferent = [&](auto element) {
+ auto replacement = replace(element);
+ return (replacement && replacement != element) ? replacement : nullptr;
+ };
+ // Check the attribute dictionary for replacements.
+ if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
+ op->setAttrs(cast<DictionaryAttr>(newAttrs));
- // Guard against potentially null inputs. We always map null to null.
- if (!element)
+ // If we aren't updating locations or types, we're done.
+ if (!replaceTypes && !replaceLocs)
return;
- // Check for an existing mapping for this element, and walk it if we haven't
- // yet.
- T *mappedElement = &visited[element];
- if (!*mappedElement) {
- WalkResult result = WalkResult::advance();
- std::tie(*mappedElement, result) = walkFn(element);
-
- // Try walking this element.
- if (result.wasInterrupted() || !*mappedElement) {
- changed = failure();
- return;
- }
+ // Update the location.
+ if (replaceLocs) {
+ if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
+ op->setLoc(cast<LocationAttr>(newLoc));
+ }
- // Handle replacing sub-elements if this element is also a container.
- if (!result.wasSkipped()) {
- if (auto interface = mappedElement->template dyn_cast<InterfaceT>()) {
- // Cache the size of the `visited` map since it may grow when calling
- // `replaceSubElementFn` and we would need to fetch again the (now
- // invalidated) reference to `mappedElement`.
- size_t visitedSize = visited.size();
- auto replacedElement = replaceSubElementFn(interface);
- if (!replacedElement) {
- changed = failure();
- return;
+ // Update the result types.
+ if (replaceTypes) {
+ for (OpResult result : op->getResults())
+ if (Type newType = replaceIfDifferent(result.getType()))
+ result.setType(newType);
+ }
+
+ // Update any nested block arguments.
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ for (BlockArgument &arg : block.getArguments()) {
+ if (replaceLocs) {
+ if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
+ arg.setLoc(cast<LocationAttr>(newLoc));
+ }
+
+ if (replaceTypes) {
+ if (Type newType = replaceIfDifferent(arg.getType()))
+ arg.setType(newType);
}
- if (visitedSize != visited.size())
- mappedElement = &visited[element];
- *mappedElement = replacedElement;
}
}
}
+}
+
+template <typename T>
+static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+ DenseMap<T, T> &elementMap,
+ SmallVectorImpl<T> &newElements,
+ FailureOr<bool> &changed) {
+ // Bail early if we failed at any point.
+ if (failed(changed))
+ return;
+
+ // Guard against potentially null inputs. We always map null to null.
+ if (!element) {
+ newElements.push_back(nullptr);
+ return;
+ }
- // Update to the mapped element.
- if (*mappedElement != element) {
- newElements.back() = *mappedElement;
- changed = true;
+ // Replace the element.
+ if (T result = replacer.replace(element)) {
+ newElements.push_back(result);
+ if (result != element)
+ changed = true;
+ } else {
+ changed = failure();
}
}
-template <typename InterfaceT>
-static typename InterfaceT::ValueType
-replaceSubElementsImpl(InterfaceT interface,
- SubElementResultReplFn<Attribute> walkAttrsFn,
- SubElementResultReplFn<Type> walkTypesFn,
- DenseMap<Attribute, Attribute> &visitedAttrs,
- DenseMap<Type, Type> &visitedTypes) {
+template <typename InterfaceT, typename T>
+T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
+ DenseMap<T, T> &interfaceMap) {
// Walk the current sub-elements, replacing them as necessary.
SmallVector<Attribute, 16> newAttrs;
SmallVector<Type, 16> newTypes;
FailureOr<bool> changed = false;
- auto replaceSubElementFn = [&](auto subInterface) {
- return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn,
- visitedAttrs, visitedTypes);
- };
interface.walkImmediateSubElements(
[&](Attribute element) {
- updateSubElementImpl<SubElementAttrInterface>(
- element, walkAttrsFn, visitedAttrs, newAttrs, changed,
- replaceSubElementFn);
+ updateSubElementImpl(element, *this, attrMap, newAttrs, changed);
},
[&](Type element) {
- updateSubElementImpl<SubElementTypeInterface>(
- element, walkTypesFn, visitedTypes, newTypes, changed,
- replaceSubElementFn);
+ updateSubElementImpl(element, *this, typeMap, newTypes, changed);
});
if (failed(changed))
- return {};
+ return nullptr;
- // If the sub-elements didn't change, just return the original value.
- if (!*changed)
- return interface;
+ // If any sub-elements changed, use the new elements during the replacement.
+ T result = interface;
+ if (*changed)
+ result = interface.replaceImmediateSubElements(newAttrs, newTypes);
+ return result;
+}
+
+/// Shared implementation of replacing a given attribute or type element.
+template <typename InterfaceT, typename ReplaceFns, typename T>
+T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
+ DenseMap<T, T> &map) {
+ auto [it, inserted] = map.try_emplace(element, element);
+ if (!inserted)
+ return it->second;
+
+ T result = element;
+ WalkResult walkResult = WalkResult::advance();
+ for (auto &replaceFn : llvm::reverse(replaceFns)) {
+ if (Optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
+ std::tie(result, walkResult) = *newRes;
+ break;
+ }
+ }
+
+ // If an error occurred, return nullptr to indicate failure.
+ if (walkResult.wasInterrupted() || !result)
+ return map[element] = nullptr;
+
+ // Handle replacing sub-elements if this element is also a container.
+ if (!walkResult.wasSkipped()) {
+ if (auto interface = dyn_cast<InterfaceT>(result)) {
+ // Replace the sub elements of this element, bailing if we fail.
+ if (!(result = replaceSubElements(interface, map)))
+ return map[element] = nullptr;
+ }
+ }
- // Use the new elements during the replacement.
- return interface.replaceImmediateSubElements(newAttrs, newTypes);
+ return map[element] = result;
}
-Attribute SubElementAttrInterface::replaceSubElements(
- SubElementResultReplFn<Attribute> replaceAttrFn,
- SubElementResultReplFn<Type> replaceTypeFn) {
- assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
- DenseMap<Attribute, Attribute> visitedAttrs;
- DenseMap<Type, Type> visitedTypes;
- return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
- visitedAttrs, visitedTypes);
+Attribute AttrTypeReplacer::replace(Attribute attr) {
+ return replaceImpl<SubElementAttrInterface>(attr, attrReplacementFns,
+ attrMap);
}
-Type SubElementTypeInterface::replaceSubElements(
- SubElementResultReplFn<Attribute> replaceAttrFn,
- SubElementResultReplFn<Type> replaceTypeFn) {
- assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
- DenseMap<Attribute, Attribute> visitedAttrs;
- DenseMap<Type, Type> visitedTypes;
- return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
- visitedAttrs, visitedTypes);
+Type AttrTypeReplacer::replace(Type type) {
+ return replaceImpl<SubElementTypeInterface>(type, typeReplacementFns,
+ typeMap);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 1acaea7c0b6df..2874ddcbbc97d 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -853,40 +853,31 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr oldAttr = scope.symbol;
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
-
- auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
- auto remapAttrFn =
- [&](Attribute attr) -> std::pair<Attribute, WalkResult> {
- // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
- // want to accidentally replace an inner reference.
- if (attr == oldAttr)
- return {newAttr, WalkResult::skip()};
- // Handle prefix matches.
- if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
- if (isReferencePrefixOf(oldAttr, symRef)) {
+ AttrTypeReplacer replacer;
+ replacer.addReplacement(
+ [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
+ // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
+ // want to accidentally replace an inner reference.
+ if (attr == oldAttr)
+ return {newAttr, WalkResult::skip()};
+ // Handle prefix matches.
+ if (isReferencePrefixOf(oldAttr, attr)) {
auto oldNestedRefs = oldAttr.getNestedReferences();
- auto nestedRefs = symRef.getNestedReferences();
+ auto nestedRefs = attr.getNestedReferences();
if (oldNestedRefs.empty())
return {SymbolRefAttr::get(newSymbol, nestedRefs),
WalkResult::skip()};
auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
- return {
- SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs),
- WalkResult::skip()};
+ return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
+ WalkResult::skip()};
}
return {attr, WalkResult::skip()};
- }
- return {attr, WalkResult::advance()};
- };
- // Generate a new attribute dictionary by replacing references to the old
- // symbol.
- auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
- if (!newDict)
- return WalkResult::interrupt();
-
- op->setAttrs(newDict.template cast<DictionaryAttr>());
+ });
+
+ auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
+ replacer.replaceElementsIn(op);
return WalkResult::advance();
};
if (!scope.walkSymbolTable(walkFn))
More information about the Mlir-commits
mailing list