[Mlir-commits] [mlir] 01eedbc - [mlir] Refactor SubElementInterface replace support
River Riddle
llvmlistbot at llvm.org
Tue Jul 26 14:51:50 PDT 2022
Author: River Riddle
Date: 2022-07-26T14:51:22-07:00
New Revision: 01eedbc7c14859c273bbd98693c67f35c59e8d85
URL: https://github.com/llvm/llvm-project/commit/01eedbc7c14859c273bbd98693c67f35c59e8d85
DIFF: https://github.com/llvm/llvm-project/commit/01eedbc7c14859c273bbd98693c67f35c59e8d85.diff
LOG: [mlir] Refactor SubElementInterface replace support
The current support was essentially the amount necessary
to support replacing SymbolRefAttrs, but suffers from various
deficiencies (both ergonomic and functional):
* Replace crashes if unsupported
This makes it really hard to use safely, given that you don't know
if you are going to crash or not when using it.
* Types aren't supported
This seems like a simple missed addition when the attribute replacement
support was originally added.
* The ergonomics are weird
It currently uses an index-based replacement, which makes the implementations
quite clunky.
This commit refactors support to be a bit more ergonomic, and also
adds support for types in the process. This was also a great opportunity
to greatly simplify how replacement is done in the symbol table.
Fixes #56355
Differential Revision: https://reviews.llvm.org/D130589
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/SubElementInterfaces.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index e415061768fe4..ce69a2e583aeb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -116,6 +116,8 @@ class LLVMArrayType
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
@@ -177,6 +179,8 @@ class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
@@ -244,6 +248,8 @@ class LLVMPointerType
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
@@ -375,6 +381,8 @@ class LLVMStructType
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
@@ -408,7 +416,7 @@ class LLVMFixedVectorType
Type getElementType() const;
/// Returns the number of elements in the fixed vector.
- unsigned getNumElements();
+ unsigned getNumElements() const;
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
@@ -416,6 +424,8 @@ class LLVMFixedVectorType
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
@@ -450,7 +460,7 @@ class LLVMScalableVectorType
/// Returns the scaling factor of the number of elements in the vector. The
/// vector contains at least the resulting number of elements, or any non-zero
/// multiple of this number.
- unsigned getMinNumElements();
+ unsigned getMinNumElements() const;
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
@@ -458,6 +468,8 @@ class LLVMScalableVectorType
void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const;
+ Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 503c0209eafa1..67fa0a31b5670 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -72,8 +72,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
//===----------------------------------------------------------------------===//
def Builtin_ArrayAttr : Builtin_Attr<"Array", [
- DeclareAttrInterfaceMethods<SubElementAttrInterface,
- ["replaceImmediateSubAttribute"]>
+ DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "A collection of other Attribute values";
let description = [{
@@ -425,8 +424,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
- DeclareAttrInterfaceMethods<SubElementAttrInterface,
- ["replaceImmediateSubAttribute"]>
+ DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "An dictionary of named Attribute values";
let description = [{
@@ -1046,7 +1044,9 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
+def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
+ DeclareAttrInterfaceMethods<SubElementAttrInterface>
+ ]> {
let summary = "An Attribute containing a symbolic reference to an Operation";
let description = [{
Syntax:
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index 9ee90513cdda3..f1aede639cf09 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -21,7 +21,8 @@ include "mlir/IR/OpBase.td"
// SubElementInterfaceBase
//===----------------------------------------------------------------------===//
-class SubElementInterfaceBase<string interfaceName, string derivedValue> {
+class SubElementInterfaceBase<string interfaceName, string attrOrType,
+ string derivedValue> {
string cppNamespace = "::mlir";
list<InterfaceMethod> methods = [
@@ -35,52 +36,78 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
>,
InterfaceMethod<
/*desc=*/[{
- Replace the attributes identified by the indices with the corresponding
- value. The index is derived from the order of the attributes returned by
- the attribute callback of `walkImmediateSubElements`. An index of 0 would
- replace the very first attribute given by `walkImmediateSubElements`.
- The new instance with the values replaced is returned.
- }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
- (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
- [{}],
- /*defaultImplementation=*/[{
- llvm_unreachable("Attribute or Type does not support replacing attributes");
- }]
- >,
+ Replace the immediately nested sub-attributes and sub-types with those provided.
+ Sub-attributes and sub-types are provided as two separate lists, each ordered
+ exactly as the corresponding callback of `walkImmediateSubElements` visits
+ them: `replAttrs[0]` replaces the first attribute visited, and `replTypes[0]`
+ replaces the first type visited.
+ On success, the new instance with the values replaced is returned. If replacement
+ fails, nullptr is returned.
+ }], attrOrType, "replaceImmediateSubElements", (ins
+ "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
+ "::llvm::ArrayRef<::mlir::Type>":$replTypes
+ )>,
];
code extraClassDeclaration = [{
- /// Walk all of the held sub-attributes.
- void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
- walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
- }
-
- /// Walk all of the held sub-types.
- void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
- walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
- }
-
/// Walk all of the held sub-attributes and sub-types.
void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
llvm::function_ref<void(mlir::Type)> walkTypesFn);
- }];
+ /// Recursively replace all of the nested sub-attributes and sub-types using the
+ /// provided map functions. Returns nullptr in the case of failure.
+ }] # attrOrType # [{ replaceSubElements(
+ llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
+ llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn
+ );
+ }];
code extraTraitClassDeclaration = [{
+ /// Walk all of the held sub-attributes and sub-types.
+ void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+ llvm::function_ref<void(mlir::Type)> walkTypesFn) {
+ }] # interfaceName # " interface(" # derivedValue # [{);
+ interface.walkSubElements(walkAttrsFn, walkTypesFn);
+ }
+
+ /// Recursively replace all of the nested sub-attributes and sub-types using the
+ /// provided map functions. Returns nullptr in the case of failure.
+ }] # attrOrType # [{ replaceSubElements(
+ llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
+ llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+ }] # interfaceName # " interface(" # derivedValue # [{);
+ return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
+ }
+
+ /// Default for elements that do not support replacing their immediate
+ /// sub-elements: replacement always fails and nullptr is returned. The
+ /// signature mirrors the interface method (attribute and type lists, const)
+ /// so that it can act as that method's fallback; elements that support
+ /// replacement declare their own `replaceImmediateSubElements`, which shadows
+ /// this one.
+ }] # attrOrType # [{ replaceImmediateSubElements(
+ llvm::ArrayRef<mlir::Attribute> replAttrs,
+ llvm::ArrayRef<mlir::Type> replTypes) const {
+ return nullptr;
+ }
+ }];
+ code extraSharedClassDeclaration = [{
/// Walk all of the held sub-attributes.
void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
}
-
/// Walk all of the held sub-types.
void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
}
-
- /// Walk all of the held sub-attributes and sub-types.
- void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
- llvm::function_ref<void(mlir::Type)> walkTypesFn) {
- }] # interfaceName # " interface(" # derivedValue # [{);
- interface.walkSubElements(walkAttrsFn, walkTypesFn);
+
+ /// Recursively replace all of the nested sub-attributes using the provided
+ /// map function. Returns nullptr in the case of failure.
+ }] # attrOrType # [{ replaceSubElements(
+ llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn) {
+ return replaceSubElements(
+ replaceAttrFn, [](mlir::Type type) { return type; });
+ }
+ /// Recursively replace all of the nested sub-types using the provided map
+ /// function. Returns nullptr in the case of failure.
+ }] # attrOrType # [{ replaceSubElements(
+ llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+ return replaceSubElements(
+ [](mlir::Attribute attr) { return attr; }, replaceTypeFn);
}
}];
}
@@ -91,7 +118,8 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
def SubElementAttrInterface
: AttrInterface<"SubElementAttrInterface">,
- SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> {
+ SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute",
+ "$_attr"> {
let description = [{
An interface used to query and manipulate sub-elements, such as sub-types
and sub-attributes of a composite attribute.
@@ -104,7 +132,8 @@ def SubElementAttrInterface
def SubElementTypeInterface
: TypeInterface<"SubElementTypeInterface">,
- SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> {
+ SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type",
+ "$_type"> {
let description = [{
An interface used to query and manipulate sub-elements, such as sub-types
and sub-attributes of a composite type.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 49d2d8d24963b..3f39a6e07748b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -92,6 +92,11 @@ void LLVMArrayType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the array type around the replacement element type. The element
+/// type is the only sub-element; the element count is carried over unchanged.
+Type LLVMArrayType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type newElementType = replTypes[0];
+  auto numElements = getNumElements();
+  return get(newElementType, numElements);
+}
+
//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//
@@ -166,6 +171,11 @@ void LLVMFunctionType::walkImmediateSubElements(
walkTypesFn(type);
}
+/// Rebuilds the function type from replacement sub-types. The first type is
+/// the result type and the rest are the parameter types; the variadic flag is
+/// preserved.
+Type LLVMFunctionType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type resultType = replTypes.front();
+  ArrayRef<Type> paramTypes = replTypes.drop_front();
+  return get(resultType, paramTypes, isVarArg());
+}
+
//===----------------------------------------------------------------------===//
// Pointer type.
//===----------------------------------------------------------------------===//
@@ -374,6 +384,11 @@ void LLVMPointerType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the pointer type around the replacement pointee type, keeping the
+/// original address space.
+Type LLVMPointerType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  auto addressSpace = getAddressSpace();
+  return get(replTypes[0], addressSpace);
+}
+
//===----------------------------------------------------------------------===//
// Struct type.
//===----------------------------------------------------------------------===//
@@ -617,6 +632,13 @@ void LLVMStructType::walkImmediateSubElements(
walkTypesFn(type);
}
+/// Replacing sub-elements is not supported for struct types, so this always
+/// reports failure by returning a null type.
+Type LLVMStructType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  // TODO: Struct types may be mutable, and it is not yet clear how replacing
+  // the sub-elements of a mutable type should work.
+  return Type();
+}
+
//===----------------------------------------------------------------------===//
// Vector types.
//===----------------------------------------------------------------------===//
@@ -653,7 +675,7 @@ Type LLVMFixedVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
-unsigned LLVMFixedVectorType::getNumElements() {
+/// Returns the number of elements stored in this type's storage. Made const so
+/// that const members such as replaceImmediateSubElements can call it.
+unsigned LLVMFixedVectorType::getNumElements() const {
return getImpl()->numElements;
}
@@ -674,6 +696,11 @@ void LLVMFixedVectorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the fixed vector type around the replacement element type; the
+/// element count is preserved.
+Type LLVMFixedVectorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type elementType = replTypes.front();
+  return LLVMFixedVectorType::get(elementType, getNumElements());
+}
+
//===----------------------------------------------------------------------===//
// LLVMScalableVectorType.
//===----------------------------------------------------------------------===//
@@ -696,7 +723,7 @@ Type LLVMScalableVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
-unsigned LLVMScalableVectorType::getMinNumElements() {
+/// Returns the minimum element count stored in this type's storage. Made const
+/// so that const members such as replaceImmediateSubElements can call it.
+unsigned LLVMScalableVectorType::getMinNumElements() const {
return getImpl()->numElements;
}
@@ -720,6 +747,11 @@ void LLVMScalableVectorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the scalable vector type around the replacement element type; the
+/// minimum element count is preserved.
+Type LLVMScalableVectorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type elementType = replTypes.front();
+  return LLVMScalableVectorType::get(elementType, getMinNumElements());
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 218622ba9d73b..5b282bf868b59 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -54,13 +54,10 @@ void ArrayAttr::walkImmediateSubElements(
walkAttrsFn(attr);
}
-SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
- ArrayRef<std::pair<size_t, Attribute>> replacements) const {
- std::vector<Attribute> vector = getValue().vec();
- for (auto &it : replacements) {
- vector[it.first] = it.second;
- }
- return get(getContext(), vector);
+/// Builds a new array from the replacement attributes. Assumes
+/// walkImmediateSubElements visits every element of the array in order, so the
+/// replacements are used directly as the new contents.
+Attribute
+ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return get(getContext(), replAttrs);
}
//===----------------------------------------------------------------------===//
@@ -227,11 +224,12 @@ void DictionaryAttr::walkImmediateSubElements(
walkAttrsFn(attr.getValue());
}
-SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
- ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+Attribute
+DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
std::vector<NamedAttribute> vec = getValue().vec();
- for (auto &it : replacements)
- vec[it.first].setValue(it.second);
+ for (auto &it : llvm::enumerate(replAttrs))
+ vec[it.index()].setValue(it.value());
// The above only modifies the mapped value, but not the key, and therefore
// not the order of the elements. It remains sorted
@@ -326,6 +324,24 @@ StringAttr SymbolRefAttr::getLeafReference() const {
return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
}
+/// Visits the root reference first and then each nested reference in order.
+/// replaceImmediateSubElements depends on exactly this ordering.
+void SymbolRefAttr::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkAttrsFn(getRootReference());
+  llvm::for_each(getNestedReferences(), walkAttrsFn);
+}
+
+/// Rebuilds this reference from replacement sub-elements given in walk order:
+/// the root reference first, then each nested reference. Following the
+/// interface contract, a replacement of the wrong kind (a non-StringAttr root,
+/// or a nested reference that is no longer flat) makes this return nullptr
+/// instead of asserting.
+Attribute
+SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                           ArrayRef<Type> replTypes) const {
+  auto rootRef = replAttrs.front().dyn_cast<StringAttr>();
+  if (!rootRef)
+    return nullptr;
+
+  ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
+  // The nested references are reinterpreted in place as FlatSymbolRefAttrs
+  // below, so every replacement must actually be a flat reference.
+  if (!llvm::all_of(rawNestedRefs, [](Attribute attr) {
+        return attr.isa<FlatSymbolRefAttr>();
+      }))
+    return nullptr;
+  ArrayRef<FlatSymbolRefAttr> nestedRefs(
+      static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
+      rawNestedRefs.size());
+  return get(rootRef, nestedRefs);
+}
+
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
@@ -1711,3 +1727,9 @@ void TypeAttr::walkImmediateSubElements(
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getValue());
}
+
+/// Wraps the replacement type, which is the attribute's only sub-element.
+Attribute
+TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                      ArrayRef<Type> replTypes) const {
+  Type newValue = replTypes.front();
+  return TypeAttr::get(newValue);
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 0d8ede99c92e6..5459841c8391a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -193,6 +193,13 @@ void FunctionType::walkImmediateSubElements(
walkTypesFn(type);
}
+/// Rebuilds the function type from replacement sub-types. The first
+/// getNumInputs() types are the inputs and the remainder are the results.
+Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                               ArrayRef<Type> replTypes) const {
+  const unsigned inputCount = getNumInputs();
+  ArrayRef<Type> newInputs = replTypes.take_front(inputCount);
+  ArrayRef<Type> newResults = replTypes.drop_front(inputCount);
+  return get(getContext(), newInputs, newResults);
+}
+
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
@@ -256,6 +263,11 @@ void VectorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the vector type with the replacement element type, keeping the
+/// shape and the number of scalable dimensions.
+Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                             ArrayRef<Type> replTypes) const {
+  Type elementType = replTypes[0];
+  return VectorType::get(getShape(), elementType, getNumScalableDims());
+}
+
VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
@@ -338,6 +350,12 @@ void RankedTensorType::walkImmediateSubElements(
walkAttrsFn(encoding);
}
+/// Rebuilds the ranked tensor type with the replacement element type and, if an
+/// attribute was supplied, the last one as the new encoding. An empty attribute
+/// list yields a null encoding.
+Type RankedTensorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Attribute newEncoding;
+  if (!replAttrs.empty())
+    newEncoding = replAttrs.back();
+  return get(getShape(), replTypes.front(), newEncoding);
+}
+
//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//
@@ -354,6 +372,11 @@ void UnrankedTensorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+/// Rebuilds the unranked tensor type around the replacement element type, its
+/// only sub-element.
+Type UnrankedTensorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type elementType = replTypes[0];
+  return UnrankedTensorType::get(elementType);
+}
+
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
@@ -663,6 +686,15 @@ void MemRefType::walkImmediateSubElements(
walkAttrsFn(getMemorySpace());
}
+/// Rebuilds the memref type from replacement sub-elements. The element type is
+/// replTypes[0]. The memory space is always the last attribute walked, so two
+/// attributes mean the first is a replacement layout. Returns nullptr if that
+/// layout replacement is not a MemRefLayoutAttrInterface; otherwise a null
+/// layout would be passed to `get` and the replacement would be silently lost.
+Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                             ArrayRef<Type> replTypes) const {
+  bool hasLayout = replAttrs.size() > 1;
+  MemRefLayoutAttrInterface layout;
+  if (hasLayout) {
+    layout = replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>();
+    if (!layout)
+      return nullptr;
+  }
+  Attribute memorySpace = hasLayout ? replAttrs[1] : replAttrs[0];
+  return get(getShape(), replTypes[0], layout, memorySpace);
+}
+
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
@@ -829,6 +861,11 @@ void UnrankedMemRefType::walkImmediateSubElements(
walkAttrsFn(getMemorySpace());
}
+/// Rebuilds the unranked memref type from the replacement element type and the
+/// replacement memory-space attribute.
+Type UnrankedMemRefType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  Type elementType = replTypes[0];
+  Attribute memorySpace = replAttrs[0];
+  return UnrankedMemRefType::get(elementType, memorySpace);
+}
+
//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//
@@ -859,6 +896,11 @@ void TupleType::walkImmediateSubElements(
walkTypesFn(type);
}
+/// Rebuilds the tuple with the replacement types as its elements, in order.
+Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                            ArrayRef<Type> replTypes) const {
+  MLIRContext *ctx = getContext();
+  return TupleType::get(ctx, replTypes);
+}
+
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index 4059b99b5db24..f8d47083f11c4 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -12,6 +12,13 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// SubElementInterface
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// WalkSubElements
+
template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
@@ -83,6 +90,121 @@ void SubElementTypeInterface::walkSubElements(
visitedTypes);
}
+//===----------------------------------------------------------------------===//
+// ReplaceSubElements
+
+/// Returns true if `attr` carries the `AttributeTrait::IsMutable` trait.
+static bool isMutable(Attribute attr) {
+  const bool hasMutableTrait = attr.hasTrait<AttributeTrait::IsMutable>();
+  return hasMutableTrait;
+}
+/// Returns true if `type` carries the `TypeTrait::IsMutable` trait.
+static bool isMutable(Type type) {
+  const bool hasMutableTrait = type.hasTrait<TypeTrait::IsMutable>();
+  return hasMutableTrait;
+}
+
+/// Appends the replacement for `element` to `newElements`. The element is
+/// mapped with `walkFn` (memoized in `visited`). If the mapped value is itself a
+/// container implementing `InterfaceT`, its sub-elements are replaced
+/// recursively through `replaceSubElementFn`. Null elements are passed through
+/// unchanged. `changed` is set to true when the element was replaced, and to
+/// failure when any step returns null. Once `changed` holds a failure, further
+/// calls do nothing.
+template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
+static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
+                                 DenseMap<T, T> &visited,
+                                 SmallVectorImpl<T> &newElements,
+                                 FailureOr<bool> &changed,
+                                 ReplaceSubElementFnT &&replaceSubElementFn) {
+  // Bail early if we failed at any point.
+  if (failed(changed))
+    return;
+  newElements.push_back(element);
+
+  // Guard against potentially null inputs. We always map null to null.
+  if (!element)
+    return;
+
+  // Check for an existing mapping for this element, and walk it if we haven't
+  // yet. The mapping is held by value, not as a reference into `visited`:
+  // `replaceSubElementFn` below recursively inserts into `visited`, which may
+  // grow (rehash) the map and leave such a reference dangling.
+  T mappedElement = visited.lookup(element);
+  if (!mappedElement) {
+    // Try walking this element.
+    mappedElement = walkFn(element);
+    if (!mappedElement) {
+      changed = failure();
+      return;
+    }
+    // Record the walked value before recursing, so that a nested occurrence of
+    // `element` reuses it instead of being walked again.
+    visited[element] = mappedElement;
+
+    // Handle replacing sub-elements if this element is also a container.
+    if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
+      mappedElement = replaceSubElementFn(interface);
+      if (!mappedElement) {
+        changed = failure();
+        return;
+      }
+      // Look the slot up again rather than reusing an earlier reference: the
+      // recursion above may have rehashed `visited`.
+      visited[element] = mappedElement;
+    }
+  }
+
+  // Update to the mapped element.
+  if (mappedElement != element) {
+    newElements.back() = mappedElement;
+    changed = true;
+  }
+}
+
+/// Replaces the immediate sub-elements of `interface` using `walkAttrsFn` and
+/// `walkTypesFn`, recursing into any mapped element that is itself a
+/// sub-element container. `visitedAttrs` and `visitedTypes` memoize mappings
+/// across the whole recursion. Returns:
+///  * the original value if no sub-element changed;
+///  * a null value if any replacement failed, or if `interface` is mutable and
+///    one of its sub-elements would change;
+///  * otherwise, the result of `replaceImmediateSubElements`.
+template <typename InterfaceT>
+static typename InterfaceT::ValueType
+replaceSubElementsImpl(InterfaceT interface,
+ function_ref<Attribute(Attribute)> walkAttrsFn,
+ function_ref<Type(Type)> walkTypesFn,
+ DenseMap<Attribute, Attribute> &visitedAttrs,
+ DenseMap<Type, Type> &visitedTypes) {
+ // Walk the current sub-elements, replacing them as necessary.
+ SmallVector<Attribute, 16> newAttrs;
+ SmallVector<Type, 16> newTypes;
+ FailureOr<bool> changed = false;
+ // Generic recursion helper: invoked with either an attribute or a type
+ // sub-interface, sharing the same memoization maps.
+ auto replaceSubElementFn = [&](auto subInterface) {
+ return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn,
+ visitedAttrs, visitedTypes);
+ };
+ interface.walkImmediateSubElements(
+ [&](Attribute element) {
+ updateSubElementImpl<SubElementAttrInterface>(
+ element, walkAttrsFn, visitedAttrs, newAttrs, changed,
+ replaceSubElementFn);
+ },
+ [&](Type element) {
+ updateSubElementImpl<SubElementTypeInterface>(
+ element, walkTypesFn, visitedTypes, newTypes, changed,
+ replaceSubElementFn);
+ });
+ if (failed(changed))
+ return {};
+
+ // If the sub-elements didn't change, just return the original value.
+ if (!*changed)
+ return interface;
+
+ // If this element is mutable, we don't support changing its sub elements, the
+ // sub element walk doesn't give us a valid ordering for what we need here. If
+ // we want to support mutable elements, we'll need something more.
+ if (isMutable(interface))
+ return {};
+
+ // Use the new elements during the replacement.
+ return interface.replaceImmediateSubElements(newAttrs, newTypes);
+}
+
+/// Recursively replaces the nested sub-attributes and sub-types of this
+/// attribute using the given map functions. Returns null on failure.
+Attribute SubElementAttrInterface::replaceSubElements(
+    function_ref<Attribute(Attribute)> replaceAttrFn,
+    function_ref<Type(Type)> replaceTypeFn) {
+  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
+  // Fresh memoization tables for this top-level replacement, shared by every
+  // level of the recursion.
+  DenseMap<Attribute, Attribute> attrMapping;
+  DenseMap<Type, Type> typeMapping;
+  SubElementAttrInterface root = *this;
+  return replaceSubElementsImpl(root, replaceAttrFn, replaceTypeFn,
+                                attrMapping, typeMapping);
+}
+
+/// Recursively replaces the nested sub-attributes and sub-types of this type
+/// using the given map functions. Returns null on failure.
+Type SubElementTypeInterface::replaceSubElements(
+    function_ref<Attribute(Attribute)> replaceAttrFn,
+    function_ref<Type(Type)> replaceTypeFn) {
+  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
+  // Fresh memoization tables for this top-level replacement, shared by every
+  // level of the recursion.
+  DenseMap<Attribute, Attribute> attrMapping;
+  DenseMap<Type, Type> typeMapping;
+  SubElementTypeInterface root = *this;
+  return replaceSubElementsImpl(root, replaceAttrFn, replaceTypeFn,
+                                attrMapping, typeMapping);
+}
+
//===----------------------------------------------------------------------===//
// SubElementInterface Tablegen definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index b55487c84d910..fb56d91f68a6c 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -97,6 +97,18 @@ walkSymbolTable(MutableArrayRef<Region> regions,
return WalkResult::advance();
}
+/// Invokes `callback` on `op` itself. If that returns `WalkResult::advance()`
+/// and `op` is not a symbol table, continues with the operations nested in
+/// `op`'s regions, without entering nested symbol tables. Otherwise the
+/// callback's result (possibly empty) is returned immediately.
+static Optional<WalkResult>
+walkSymbolTable(Operation *op,
+                function_ref<Optional<WalkResult>(Operation *)> callback) {
+  Optional<WalkResult> opResult = callback(op);
+  // Stop if the callback did not ask to keep going...
+  if (opResult != WalkResult::advance())
+    return opResult;
+  // ...or if `op` opens its own symbol table, whose contents are out of scope.
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    return opResult;
+  return walkSymbolTable(op->getRegions(), callback);
+}
+
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
@@ -465,21 +477,11 @@ LogicalResult detail::verifySymbol(Operation *op) {
//===----------------------------------------------------------------------===//
/// Walk all of the symbol references within the given operation, invoking the
-/// provided callback for each found use. The callbacks takes as arguments: the
-/// use of the symbol, and the nested access chain to the attribute within the
-/// operation dictionary. An access chain is a set of indices into nested
-/// container attributes. For example, a symbol use in an attribute dictionary
-/// that looks like the following:
-///
-/// {use = [{other_attr, @symbol}]}
-///
-/// May have the following access chain:
-///
-/// [0, 0, 1]
-///
-static WalkResult walkSymbolRefs(
- Operation *op,
- function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+/// provided callback for each found use. The callbacks takes the use of the
+/// symbol.
+static WalkResult
+walkSymbolRefs(Operation *op,
+ function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
// Check to see if the operation has any attributes.
DictionaryAttr attrDict = op->getAttrDictionary();
if (attrDict.empty())
@@ -507,20 +509,19 @@ static WalkResult walkSymbolRefs(
WorklistItem &worklistItem) -> WalkResult {
for (Attribute attr :
llvm::drop_begin(worklistItem.immediateSubElements, index)) {
- /// Check for a nested container attribute, these will also need to be
- /// walked.
- if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
- attrWorklist.emplace_back(interface);
- curAccessChain.push_back(-1);
- return WalkResult::advance();
- }
-
// Invoke the provided callback if we find a symbol use and check for a
// requested interrupt.
- if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>())
- if (callback({op, symbolRef}, curAccessChain).wasInterrupted())
+ if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
+ if (callback({op, symbolRef}).wasInterrupted())
return WalkResult::interrupt();
+ /// Check for a nested container attribute, these will also need to be
+ /// walked.
+ } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+ attrWorklist.emplace_back(interface);
+ curAccessChain.push_back(-1);
+ return WalkResult::advance();
+ }
// Make sure to keep the index counter in sync.
++index;
}
@@ -546,9 +547,9 @@ static WalkResult walkSymbolRefs(
/// Walk all of the uses, for any symbol, that are nested within the given
/// regions, invoking the provided callback for each. This does not traverse
/// into any nested symbol tables.
-static Optional<WalkResult> walkSymbolUses(
- MutableArrayRef<Region> regions,
- function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+static Optional<WalkResult>
+walkSymbolUses(MutableArrayRef<Region> regions,
+ function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
// Check that this isn't a potentially unknown symbol table.
if (isPotentiallyUnknownSymbolTable(op))
@@ -560,9 +561,9 @@ static Optional<WalkResult> walkSymbolUses(
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables.
-static Optional<WalkResult> walkSymbolUses(
- Operation *from,
- function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+static Optional<WalkResult>
+walkSymbolUses(Operation *from,
+ function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
// If this operation has regions, and it, as well as its dialect, isn't
// registered then conservatively fail. The operation may define a
// symbol table, so we can't opaquely know if we should traverse to find
@@ -608,11 +609,20 @@ struct SymbolScope {
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
Optional<WalkResult> walk(CallbackT cback) {
- return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
+ return walk([=](SymbolTable::SymbolUse use) {
return cback(use), WalkResult::advance();
});
}
+ /// Walk all of the operations nested under the current scope without
+ /// traversing into any nested symbol tables.
+ template <typename CallbackT>
+ Optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
+ if (Region *region = limit.dyn_cast<Region *>())
+ return ::walkSymbolTable(*region, cback);
+ return ::walkSymbolTable(limit.get<Operation *>(), cback);
+ }
+
/// The representation of the symbol within this scope.
SymbolRefAttr symbol;
@@ -723,7 +733,7 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
template <typename FromT>
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
std::vector<SymbolTable::SymbolUse> uses;
- auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
uses.push_back(symbolUse);
return WalkResult::advance();
};
@@ -792,7 +802,7 @@ template <typename SymbolT, typename IRUnitT>
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
// Walk all of the symbol uses looking for a reference to 'symbol'.
- if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
? WalkResult::interrupt()
: WalkResult::advance();
@@ -822,50 +832,6 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
//===----------------------------------------------------------------------===//
// SymbolTable::replaceAllSymbolUses
-/// Rebuild the given attribute container after replacing all references to a
-/// symbol with the updated attribute in 'accesses'.
-static SubElementAttrInterface rebuildAttrAfterRAUW(
- SubElementAttrInterface container,
- ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
- unsigned depth) {
- // Given a range of Attributes, update the ones referred to by the given
- // access chains to point to the new symbol attribute.
-
- SmallVector<std::pair<size_t, Attribute>> replacements;
-
- SmallVector<Attribute> subElements;
- container.walkImmediateSubElements(
- [&](Attribute attribute) { subElements.push_back(attribute); },
- [](Type) {});
- for (unsigned i = 0, e = accesses.size(); i != e;) {
- ArrayRef<int> access = accesses[i].first;
-
- // Check to see if this is a leaf access, i.e. a SymbolRef.
- if (access.size() == depth + 1) {
- replacements.emplace_back(access.back(), accesses[i].second);
- ++i;
- continue;
- }
-
- // Otherwise, this is a container. Collect all of the accesses for this
- // index and recurse. The recursion here is bounded by the size of the
- // largest access array.
- auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
- ArrayRef<int> nextAccess = it.first;
- return nextAccess.size() > depth + 1 &&
- nextAccess[depth] == access[depth];
- });
- auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
- nestedAccesses, depth + 1);
- replacements.emplace_back(access[depth], result);
-
- // Skip over all of the accesses that refer to the nested container.
- i += nestedAccesses.size();
- }
-
- return container.replaceImmediateSubAttribute(replacements);
-}
-
/// Generates a new symbol reference attribute with a new leaf reference.
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
FlatSymbolRefAttr newLeafAttr) {
@@ -880,77 +846,43 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
template <typename SymbolT, typename IRUnitT>
static LogicalResult
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
- // A collection of operations along with their new attribute dictionary.
- std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
-
- // The current operation being processed.
- Operation *curOp = nullptr;
-
- // The set of access chains into the attribute dictionary of the current
- // operation, as well as the replacement attribute to use.
- SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains;
-
- // Generate a new attribute dictionary for the current operation by replacing
- // references to the old symbol.
- auto generateNewAttrDict = [&] {
- auto oldDict = curOp->getAttrDictionary();
- auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
- return newDict.cast<DictionaryAttr>();
- };
-
// Generate a new attribute to replace the given attribute.
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
+ SymbolRefAttr oldAttr = scope.symbol;
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
- auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
- ArrayRef<int> accessChain) {
- SymbolRefAttr useRef = symbolUse.getSymbolRef();
- if (!isReferencePrefixOf(scope.symbol, useRef))
- return WalkResult::advance();
- // If we have a valid match, check to see if this is a proper
- // subreference. If it is, then we will need to generate a different new
- // attribute specifically for this use.
- SymbolRefAttr replacementRef = newAttr;
- if (useRef != scope.symbol) {
- if (scope.symbol.isa<FlatSymbolRefAttr>()) {
- replacementRef =
- SymbolRefAttr::get(newSymbol, useRef.getNestedReferences());
- } else {
- auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
- nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
- newLeafAttr;
- replacementRef =
- SymbolRefAttr::get(useRef.getRootReference(), nestedRefs);
+ auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
+ auto remapAttrFn = [&](Attribute attr) -> Attribute {
+ if (attr == oldAttr)
+ return newAttr;
+ // Handle prefix matches.
+ if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
+ if (isReferencePrefixOf(oldAttr, symRef)) {
+ auto oldNestedRefs = oldAttr.getNestedReferences();
+ auto nestedRefs = symRef.getNestedReferences();
+ if (oldNestedRefs.empty())
+ return SymbolRefAttr::get(newSymbol, nestedRefs);
+
+ auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
+ newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
+ return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
+ }
}
- }
-
- // If there was a previous operation, generate a new attribute dict
- // for it. This means that we've finished processing the current
- // operation, so generate a new dictionary for it.
- if (curOp && symbolUse.getUser() != curOp) {
- updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
- accessChains.clear();
- }
-
- // Record this access.
- curOp = symbolUse.getUser();
- accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
+ return attr;
+ };
+ // Generate a new attribute dictionary by replacing references to the old
+ // symbol.
+ auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
+ if (!newDict)
+ return WalkResult::interrupt();
+
+ op->setAttrs(newDict.template cast<DictionaryAttr>());
return WalkResult::advance();
};
- if (!scope.walk(walkFn))
+ if (!scope.walkSymbolTable(walkFn))
return failure();
-
- // Check to see if we have a dangling op that needs to be processed.
- if (curOp) {
- updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
- curOp = nullptr;
- }
}
-
- // Update the attribute dictionaries as necessary.
- for (auto &it : updatedAttrDicts)
- it.first->setAttrs(it.second);
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 10c82b74282df..fb99d6f7c9b54 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -114,9 +114,8 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
DeclareAttrInterfaceMethods<SubElementAttrInterface,
- ["replaceImmediateSubAttribute"]>
+ ["replaceImmediateSubElements"]>
]> {
-
let mnemonic = "sub_elements_access";
let parameters = (ins
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2decac866afa9..810380cf9ff19 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -173,25 +173,10 @@ void TestSubElementsAccessAttr::walkImmediateSubElements(
walkAttrsFn(getThird());
}
-SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
- ArrayRef<std::pair<size_t, Attribute>> replacements) const {
- Attribute first = getFirst();
- Attribute second = getSecond();
- Attribute third = getThird();
- for (auto &it : replacements) {
- switch (it.first) {
- case 0:
- first = it.second;
- break;
- case 1:
- second = it.second;
- break;
- case 2:
- third = it.second;
- break;
- }
- }
- return get(getContext(), first, second, third);
+Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
+ ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+ assert(replAttrs.size() == 3 && "invalid number of replacement attributes");
+ return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 8772efd8590fc..2ccf6ddf5f8f7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -154,6 +154,12 @@ class TestRecursiveType
::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
walkTypesFn(getBody());
}
+ Type replaceImmediateSubElements(llvm::ArrayRef<mlir::Attribute> replAttrs,
+ llvm::ArrayRef<mlir::Type> replTypes) const {
+ // TODO: It's not clear how we support replacing sub-elements of mutable
+ // types.
+ return nullptr;
+ }
};
} // namespace test
More information about the Mlir-commits
mailing list