[Mlir-commits] [mlir] 03d136c - [mlir] Promote the SubElementInterfaces to a core Attribute/Type construct
River Riddle
llvmlistbot at llvm.org
Fri Jan 27 15:28:23 PST 2023
Author: River Riddle
Date: 2023-01-27T15:28:03-08:00
New Revision: 03d136cf5f3f10b618b7e17f897ebf6019518dcc
URL: https://github.com/llvm/llvm-project/commit/03d136cf5f3f10b618b7e17f897ebf6019518dcc
DIFF: https://github.com/llvm/llvm-project/commit/03d136cf5f3f10b618b7e17f897ebf6019518dcc.diff
LOG: [mlir] Promote the SubElementInterfaces to a core Attribute/Type construct
This commit restructures the sub-element infrastructure to be a core part
of attributes and types, instead of being relegated to an interface. This
establishes sub-element walking/replacement as something that is "always there",
which makes it easier to rely on for correctness and similar guarantees (which
various pieces of infrastructure, such as Symbols, need).
Attribute/Type now have directly accessible `walk` and `replace` methods,
which provide a powerful API for interacting with sub-elements. As part of
this, a new AttrTypeWalker class is introduced that supports caching of
walked attributes/types and offers a friendlier API (see the simplification
of symbol walking in SymbolTable.cpp).
Differential Revision: https://reviews.llvm.org/D142272
Added:
mlir/include/mlir/IR/AttrTypeSubElements.h
mlir/lib/IR/AttrTypeSubElements.cpp
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/BuiltinLocationAttributes.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/CMakeLists.txt
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeRange.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/IR/TypeDetail.h
mlir/lib/IR/Types.cpp
mlir/test/IR/test-symbol-rauw.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/InterfaceTest.cpp
Removed:
mlir/include/mlir/IR/SubElementInterfaces.h
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/SubElementInterfaces.cpp
mlir/unittests/IR/SubElementInterfaceTest.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 96936c4dcdc4..957f9eef908b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -12,7 +12,6 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/LLVMIR/LLVMEnums.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/IR/SubElementInterfaces.td"
// All of the attributes will extend this class.
class LLVM_Attr<string name, string attrMnemonic,
@@ -160,9 +159,8 @@ def LLVM_DIBasicTypeAttr : LLVM_Attr<"DIBasicType", "di_basic_type",
// DICompileUnitAttr
//===----------------------------------------------------------------------===//
-def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
- SubElementAttrInterface
- ], "DIScopeAttr"> {
+def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
+ /*traits=*/[], "DIScopeAttr"> {
let parameters = (ins
LLVM_DILanguageParameter:$sourceLanguage,
"DIFileAttr":$file,
@@ -177,9 +175,8 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
// DICompositeTypeAttr
//===----------------------------------------------------------------------===//
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", [
- SubElementAttrInterface
- ], "DITypeAttr"> {
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+ /*traits=*/[], "DITypeAttr"> {
let parameters = (ins
LLVM_DITagParameter:$tag,
"StringAttr":$name,
@@ -199,9 +196,8 @@ def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
// DIDerivedTypeAttr
//===----------------------------------------------------------------------===//
-def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", [
- SubElementAttrInterface
- ], "DITypeAttr"> {
+def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type",
+ /*traits=*/[], "DITypeAttr"> {
let parameters = (ins
LLVM_DITagParameter:$tag,
OptionalParameter<"StringAttr">:$name,
@@ -231,9 +227,8 @@ def LLVM_DIFileAttr : LLVM_Attr<"DIFile", "di_file", /*traits=*/[], "DIScopeAttr
// DILexicalBlockAttr
//===----------------------------------------------------------------------===//
-def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
- SubElementAttrInterface
- ], "DIScopeAttr"> {
+def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block",
+ /*traits=*/[], "DIScopeAttr"> {
let parameters = (ins
"DIScopeAttr":$scope,
OptionalParameter<"DIFileAttr">:$file,
@@ -255,9 +250,8 @@ def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
// DILexicalBlockFileAttr
//===----------------------------------------------------------------------===//
-def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", [
- SubElementAttrInterface
- ], "DIScopeAttr"> {
+def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file",
+ /*traits=*/[], "DIScopeAttr"> {
let parameters = (ins
"DIScopeAttr":$scope,
OptionalParameter<"DIFileAttr">:$file,
@@ -277,9 +271,8 @@ def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_
// DILocalVariableAttr
//===----------------------------------------------------------------------===//
-def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", [
- SubElementAttrInterface
- ], "DINodeAttr"> {
+def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
+ /*traits=*/[], "DINodeAttr"> {
let parameters = (ins
"DIScopeAttr":$scope,
"StringAttr":$name,
@@ -307,9 +300,8 @@ def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
// DISubprogramAttr
//===----------------------------------------------------------------------===//
-def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", [
- SubElementAttrInterface
- ], "DIScopeAttr"> {
+def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
+ /*traits=*/[], "DIScopeAttr"> {
let parameters = (ins
"DICompileUnitAttr":$compileUnit,
"DIScopeAttr":$scope,
@@ -357,9 +349,8 @@ def LLVM_DISubrangeAttr : LLVM_Attr<"DISubrange", "di_subrange", /*traits=*/[],
// DISubroutineTypeAttr
//===----------------------------------------------------------------------===//
-def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", [
- SubElementAttrInterface
- ], "DITypeAttr"> {
+def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type",
+ /*traits=*/[], "DITypeAttr"> {
let parameters = (ins
LLVM_DICallingConventionParameter:$callingConvention,
OptionalArrayRefParameter<"DITypeAttr">:$types
@@ -377,9 +368,7 @@ def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_typ
// MemoryEffectsAttr
//===----------------------------------------------------------------------===//
-def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects", [
- SubElementAttrInterface
- ]> {
+def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> {
let parameters = (ins
"ModRefInfo":$other,
"ModRefInfo":$argMem,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index fcc35bdb856f..a44448e18b45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,7 +14,6 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
-#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include <optional>
@@ -104,7 +103,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait,
- SubElementTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 460e7ded84c3..caf4b58b87f5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,7 +11,6 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/SubElementInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
/// Base class for all LLVM dialect types.
@@ -25,8 +24,7 @@ class LLVMType<string typeName, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def LLVMArrayType : LLVMType<"LLVMArray", "array", [
- DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>,
- DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+ DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>]> {
let summary = "LLVM array type";
let description = [{
The `!llvm.array` type represents a fixed-size array of element types.
@@ -62,8 +60,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
// LLVMFunctionType
//===----------------------------------------------------------------------===//
-def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> {
let summary = "LLVM function type";
let description = [{
The `!llvm.func` is a function type. It consists of a single return type
@@ -124,8 +121,7 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
- "areCompatible", "verifyEntries"]>,
- DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+ "areCompatible", "verifyEntries"]>]> {
let summary = "LLVM pointer type";
let description = [{
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
@@ -171,8 +167,7 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
// LLVMFixedVectorType
//===----------------------------------------------------------------------===//
-def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
let summary = "LLVM fixed vector type";
let description = [{
LLVM dialect scalable vector type, represents a sequence of elements of
@@ -202,8 +197,7 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
// LLVMScalableVectorType
//===----------------------------------------------------------------------===//
-def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
let summary = "LLVM scalable vector type";
let description = [{
LLVM dialect scalable vector type, represents a sequence of elements of
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
similarity index 66%
rename from mlir/include/mlir/IR/SubElementInterfaces.h
rename to mlir/include/mlir/IR/AttrTypeSubElements.h
index 935d7fcd59cf..23b8d533b7cd 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -1,4 +1,4 @@
-//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===//
+//===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,112 @@
//
//===----------------------------------------------------------------------===//
//
-// This file contains interfaces and utilities for querying the sub elements of
-// an attribute or type.
+// This file contains utilities for querying the sub elements of an attribute or
+// type.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_IR_SUBELEMENTINTERFACES_H
-#define MLIR_IR_SUBELEMENTINTERFACES_H
+#ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H
+#define MLIR_IR_ATTRTYPESUBELEMENTS_H
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
#include <optional>
namespace mlir {
+class Attribute;
+class Type;
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeWalker
+//===----------------------------------------------------------------------===//
+
+/// This class provides a utility for walking attributes/types, and their sub
+/// elements. Multiple walk functions may be registered.
+class AttrTypeWalker {
+public:
+ //===--------------------------------------------------------------------===//
+ // Application
+ //===--------------------------------------------------------------------===//
+
+ /// Walk the given attribute/type and recursively walk any sub elements.
+ template <WalkOrder Order, typename T>
+ WalkResult walk(T element) {
+ return walkImpl(element, Order);
+ }
+ template <typename T>
+ WalkResult walk(T element) {
+ return walk<WalkOrder::PostOrder, T>(element);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Registration
+ //===--------------------------------------------------------------------===//
+
+ template <typename T>
+ using WalkFn = std::function<WalkResult(T)>;
+
+ /// Register a walk function for a given attribute or type. A walk function
+ /// must be convertible to any of the following forms (where `T` is a class
+ /// derived from `Type` or `Attribute`):
+ ///
+ /// * WalkResult(T)
+ /// - Returns a walk result, which can be used to control the walk.
+ ///
+ /// * void(T)
+ /// - Returns void, i.e. the walk always continues.
+ ///
+ /// Note: When walking, the most recently added walk functions will be
+ /// invoked first.
+ void addWalk(WalkFn<Attribute> &&fn) {
+ attrWalkFns.emplace_back(std::move(fn));
+ }
+ void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(std::move(fn)); }
+
+ /// Register a walk function that doesn't match the default signature,
+ /// either because it uses a derived parameter type, or it uses a simplified
+ /// result type. Elements that are not a `T` are skipped (the walk advances).
+ template <typename FnT,
+ typename T = typename llvm::function_traits<
+ std::decay_t<FnT>>::template arg_t<0>,
+ typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+ Attribute, Type>,
+ typename ResultT = std::invoke_result_t<FnT, T>>
+ std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>>
+ addWalk(FnT &&callback) {
+ addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult {
+ if (auto derived = dyn_cast<T>(base)) {
+ if constexpr (std::is_convertible_v<ResultT, WalkResult>)
+ return callback(derived);
+ else
+ callback(derived);
+ }
+ return WalkResult::advance();
+ });
+ }
+
+private:
+ WalkResult walkImpl(Attribute attr, WalkOrder order);
+ WalkResult walkImpl(Type type, WalkOrder order);
+
+ /// Internal implementation of the `walk` methods above.
+ template <typename T, typename WalkFns>
+ WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order);
+
+ /// Walk the sub elements of the given attribute/type.
+ template <typename T>
+ WalkResult walkSubElements(T interface, WalkOrder order);
+
+ /// The registered walk functions, for attributes and types respectively.
+ std::vector<WalkFn<Attribute>> attrWalkFns;
+ std::vector<WalkFn<Type>> typeWalkFns;
+
+ /// Cached walk results for already-visited attributes/types.
+ DenseMap<std::pair<const void *, int>, WalkResult> visitedAttrTypes;
+};
+
//===----------------------------------------------------------------------===//
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//
@@ -84,12 +176,8 @@ class AttrTypeReplacer {
///
/// Note: When replacing, the mostly recently added replacement functions will
/// be invoked first.
- void addReplacement(ReplaceFn<Attribute> fn) {
- attrReplacementFns.emplace_back(std::move(fn));
- }
- void addReplacement(ReplaceFn<Type> fn) {
- typeReplacementFns.push_back(std::move(fn));
- }
+ void addReplacement(ReplaceFn<Attribute> fn);
+ void addReplacement(ReplaceFn<Type> fn);
/// Register a replacement function that doesn't match the default signature,
/// either because it uses a derived parameter type, or it uses a simplified
@@ -120,20 +208,19 @@ class AttrTypeReplacer {
private:
/// Internal implementation of the `replace` methods above.
- template <typename InterfaceT, typename ReplaceFns, typename T>
- T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap<T, T> &map);
+ template <typename T, typename ReplaceFns>
+ T replaceImpl(T element, ReplaceFns &replaceFns);
/// Replace the sub elements of the given interface.
- template <typename InterfaceT, typename T = typename InterfaceT::ValueType>
- T replaceSubElements(InterfaceT interface, DenseMap<T, T> &interfaceMap);
+ template <typename T>
+ T replaceSubElements(T interface);
/// The set of replacement functions that map sub elements.
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
std::vector<ReplaceFn<Type>> typeReplacementFns;
/// The set of cached mappings for attributes/types.
- DenseMap<Attribute, Attribute> attrMap;
- DenseMap<Type, Type> typeMap;
+ DenseMap<const void *, const void *> attrTypeMap;
};
//===----------------------------------------------------------------------===//
@@ -142,22 +229,16 @@ class AttrTypeReplacer {
/// This class is used by AttrTypeSubElementHandler instances to walking sub
/// attributes and types.
-class AttrTypeSubElementWalker {
+class AttrTypeImmediateSubElementWalker {
public:
- AttrTypeSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn)
+ AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn)
: walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
/// Walk an attribute.
- void walk(Attribute element) {
- if (element)
- walkAttrsFn(element);
- }
+ void walk(Attribute element);
/// Walk a type.
- void walk(Type element) {
- if (element)
- walkTypesFn(element);
- }
+ void walk(Type element);
/// Walk a range of attributes or types.
template <typename RangeT>
void walkRange(RangeT &&elements) {
@@ -212,7 +293,8 @@ using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>;
template <typename T, typename Enable = void>
struct AttrTypeSubElementHandler {
/// Default walk implementation that does nothing.
- static inline void walk(const T &param, AttrTypeSubElementWalker &walker) {}
+ static inline void walk(const T &param,
+ AttrTypeImmediateSubElementWalker &walker) {}
/// Default replace implementation just forwards the parameter.
template <typename ParamT>
@@ -241,7 +323,7 @@ template <typename T>
struct AttrTypeSubElementHandler<
T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
std::is_base_of_v<Type, T>>> {
- static void walk(T param, AttrTypeSubElementWalker &walker) {
+ static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
walker.walk(param);
}
static T replace(T param, AttrSubElementReplacements &attrRepls,
@@ -255,27 +337,14 @@ struct AttrTypeSubElementHandler<
}
}
};
-template <>
-struct AttrTypeSubElementHandler<NamedAttribute> {
- template <typename T>
- static void walk(T param, AttrTypeSubElementWalker &walker) {
- walker.walk(param.getName());
- walker.walk(param.getValue());
- }
- template <typename T>
- static T replace(T param, AttrSubElementReplacements &attrRepls,
- TypeSubElementReplacements &typeRepls) {
- ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
- return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
- }
-};
/// Implementation for derived ArrayRef.
template <typename T>
struct AttrTypeSubElementHandler<ArrayRef<T>,
std::enable_if_t<has_sub_attr_or_type_v<T>>> {
using EltHandler = AttrTypeSubElementHandler<T>;
- static void walk(ArrayRef<T> param, AttrTypeSubElementWalker &walker) {
+ static void walk(ArrayRef<T> param,
+ AttrTypeImmediateSubElementWalker &walker) {
for (const T &subElement : param)
EltHandler::walk(subElement, walker);
}
@@ -283,11 +352,11 @@ struct AttrTypeSubElementHandler<ArrayRef<T>,
TypeSubElementReplacements &typeRepls) {
// Normal attributes/types can extract using the replacer directly.
if constexpr (std::is_base_of_v<Attribute, T> &&
- sizeof(T) == sizeof(Attribute)) {
+ sizeof(T) == sizeof(void *)) {
ArrayRef<Attribute> attrs = attrRepls.take_front(param.size());
return ArrayRef<T>((const T *)attrs.data(), attrs.size());
} else if constexpr (std::is_base_of_v<Type, T> &&
- sizeof(T) == sizeof(Type)) {
+ sizeof(T) == sizeof(void *)) {
ArrayRef<Type> types = typeRepls.take_front(param.size());
return ArrayRef<T>((const T *)types.data(), types.size());
} else {
@@ -305,7 +374,7 @@ template <typename... Ts>
struct AttrTypeSubElementHandler<
std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
static void walk(const std::tuple<Ts...> &param,
- AttrTypeSubElementWalker &walker) {
+ AttrTypeImmediateSubElementWalker &walker) {
std::apply(
[&](const Ts &...params) {
(AttrTypeSubElementHandler<Ts>::walk(params, walker), ...);
@@ -333,6 +402,8 @@ template <typename... Ts>
struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
template <typename T, typename... Ts>
using has_get_method = decltype(T::get(std::declval<Ts>()...));
+template <typename T, typename... Ts>
+using has_get_as_key = decltype(std::declval<T>().getAsKey());
/// This function provides the underlying implementation for the
/// SubElementInterface walk method, using the key type of the derived
@@ -341,21 +412,24 @@ template <typename T>
void walkImmediateSubElementsImpl(T derived,
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
- auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+ using ImplT = typename T::ImplType;
+ if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
+ auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
- // If we don't have any sub-elements, there is nothing to do.
- if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
- return;
- } else {
- AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn);
- AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
+ // If we don't have any sub-elements, there is nothing to do.
+ if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+ return;
+ } else {
+ AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn);
+ AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
+ }
}
}
/// This function invokes the proper `get` method for a type `T` with the given
/// values.
template <typename T, typename... Ts>
-T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
+auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
// Prefer a direct `get` method if one exists.
if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
(void)ctx;
@@ -373,38 +447,39 @@ T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
/// SubElementInterface replace method, using the key type of the derived
/// attribute/type to interact with the individual parameters.
template <typename T>
-T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
- ArrayRef<Type> &replTypes) {
- auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
+ ArrayRef<Type> &replTypes) {
+ using ImplT = typename T::ImplType;
+ if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
+ auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
- // If we don't have any sub-elements, we can just return the original.
- if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
- return derived;
+ // If we don't have any sub-elements, we can just return the original.
+ if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+ return derived;
- // Otherwise, we need to replace any necessary sub-elements.
- } else {
- AttrSubElementReplacements attrRepls(replAttrs);
- TypeSubElementReplacements typeRepls(replTypes);
- auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
- key, attrRepls, typeRepls);
- if constexpr (is_tuple<decltype(key)>::value) {
- return std::apply(
- [&](auto &&...params) {
- return constructSubElementReplacement<T>(
- derived.getContext(),
- std::forward<decltype(params)>(params)...);
- },
- newKey);
+ // Otherwise, we need to replace any necessary sub-elements.
} else {
- return constructSubElementReplacement<T>(derived.getContext(), newKey);
+ AttrSubElementReplacements attrRepls(replAttrs);
+ TypeSubElementReplacements typeRepls(replTypes);
+ auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
+ key, attrRepls, typeRepls);
+ if constexpr (is_tuple<decltype(key)>::value) {
+ return std::apply(
+ [&](auto &&...params) {
+ return constructSubElementReplacement<T>(
+ derived.getContext(),
+ std::forward<decltype(params)>(params)...);
+ },
+ newKey);
+ } else {
+ return constructSubElementReplacement<T>(derived.getContext(), newKey);
+ }
}
+ } else {
+ return derived;
}
}
} // namespace detail
} // namespace mlir
-/// Include the definitions of the sub element interfaces.
-#include "mlir/IR/SubElementAttrInterfaces.h.inc"
-#include "mlir/IR/SubElementTypeInterfaces.h.inc"
-
-#endif // MLIR_IR_SUBELEMENTINTERFACES_H
+#endif // MLIR_IR_ATTRTYPESUBELEMENTS_H
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 73ede2e4d481..796adbe8531a 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -20,9 +20,6 @@
#include "llvm/ADT/Twine.h"
namespace mlir {
-class MLIRContext;
-class Type;
-
//===----------------------------------------------------------------------===//
// AbstractAttribute
//===----------------------------------------------------------------------===//
@@ -32,6 +29,10 @@ class Type;
class AbstractAttribute {
public:
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+ using WalkImmediateSubElementsFn = function_ref<void(
+ Attribute, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
+ using ReplaceImmediateSubElementsFn =
+ function_ref<Attribute(Attribute, ArrayRef<Attribute>, ArrayRef<Type>)>;
/// Look up the specified abstract attribute in the MLIRContext and return a
/// reference to it.
@@ -42,6 +43,8 @@ class AbstractAttribute {
template <typename T>
static AbstractAttribute get(Dialect &dialect) {
return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+ T::getWalkImmediateSubElementsFn(),
+ T::getReplaceImmediateSubElementsFn(),
T::getTypeID());
}
@@ -49,11 +52,15 @@ class AbstractAttribute {
/// custom TypeIDs.
/// The use of this method is in general discouraged in favor of
/// 'get<CustomAttribute>(dialect)'.
- static AbstractAttribute get(Dialect &dialect,
- detail::InterfaceMap &&interfaceMap,
- HasTraitFn &&hasTrait, TypeID typeID) {
+ static AbstractAttribute
+ get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+ HasTraitFn &&hasTrait,
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+ TypeID typeID) {
return AbstractAttribute(dialect, std::move(interfaceMap),
- std::move(hasTrait), typeID);
+ std::move(hasTrait), walkImmediateSubElementsFn,
+ replaceImmediateSubElementsFn, typeID);
}
/// Return the dialect this attribute was registered to.
@@ -82,14 +89,30 @@ class AbstractAttribute {
/// Returns true if the attribute has a particular trait.
bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+ /// Walk the immediate sub-elements of this attribute.
+ void walkImmediateSubElements(Attribute attr,
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
+
+ /// Replace the immediate sub-elements of this attribute.
+ Attribute replaceImmediateSubElements(Attribute attr,
+ ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
+
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- HasTraitFn &&hasTrait, TypeID typeID)
+ HasTraitFn &&hasTraitFn,
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+ TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
+ hasTraitFn(std::move(hasTraitFn)),
+ walkImmediateSubElementsFn(walkImmediateSubElementsFn),
+ replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
+ typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -110,6 +133,12 @@ class AbstractAttribute {
/// Function to check if the attribute has a particular trait.
HasTraitFn hasTraitFn;
+ /// Function to walk the immediate sub-elements of this attribute.
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn;
+
+ /// Function to replace the immediate sub-elements of this attribute.
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
+
/// The unique identifier of the derived Attribute class.
const TypeID typeID;
};
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index a31560222ecf..a39ddf59df0a 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -101,6 +101,48 @@ class Attribute {
return impl->getAbstractAttribute();
}
+ /// Walk all of the immediately nested sub-attributes and sub-types. This
+ /// method does not recurse into sub elements; use `walk` for a recursive walk.
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ getAbstractAttribute().walkImmediateSubElements(*this, walkAttrsFn,
+ walkTypesFn);
+ }
+
+ /// Replace the immediately nested sub-attributes and sub-types with those
+ /// provided. The order of the provided elements must match the order of
+ /// the elements passed to the callbacks of `walkImmediateSubElements`. The
+ /// element at index 0 of `replAttrs` replaces the very first attribute given by
+ /// `walkImmediateSubElements`. On success, the new instance with the values
+ /// replaced is returned. If replacement fails, nullptr is returned.
+ auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return getAbstractAttribute().replaceImmediateSubElements(*this, replAttrs,
+ replTypes);
+ }
+
+ /// Walk this attribute and all attributes/types nested within using the
+ /// provided walk functions. See `AttrTypeWalker` for information on the
+ /// supported walk function types. Walks in post-order unless `Order` is given.
+ template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
+ auto walk(WalkFns &&...walkFns) {
+ AttrTypeWalker walker;
+ (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
+ return walker.walk<Order>(*this);
+ }
+
+ /// Recursively replace all of the nested sub-attributes and sub-types using
+ /// the provided replacement functions. Returns nullptr in the case of failure.
+ /// See `AttrTypeReplacer` for information on the supported replacement
+ /// function types.
+ template <typename... ReplacementFns>
+ auto replace(ReplacementFns &&...replacementFns) {
+ AttrTypeReplacer replacer;
+ (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
+ ...);
+ return replacer.replace(*this);
+ }
+
/// Return the internal Attribute implementation.
ImplType *getImpl() const { return impl; }
@@ -201,6 +243,22 @@ inline ::llvm::hash_code hash_value(const NamedAttribute &arg) {
return DenseMapInfo<AttrPairT>::getHashValue(AttrPairT(arg.name, arg.value));
}
+/// Allow walking and replacing the sub-elements of a NamedAttribute.
+///
+/// `walk` visits the name first and then the value. `replace` mirrors that
+/// order: the first attribute from `attrRepls.take_front(2)` becomes the new
+/// name and the second becomes the new value. A NamedAttribute holds no
+/// direct types, so `typeRepls` is not used.
+template <>
+struct AttrTypeSubElementHandler<NamedAttribute> {
+ template <typename T>
+ static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
+ walker.walk(param.getName());
+ walker.walk(param.getValue());
+ }
+ template <typename T>
+ static T replace(T param, AttrSubElementReplacements &attrRepls,
+ TypeSubElementReplacements &typeRepls) {
+ ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
+ // The replacement name must have the same attribute class as the original
+ // name; `cast` asserts this.
+ return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
+ }
+};
+
//===----------------------------------------------------------------------===//
// AttributeTraitBase
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index a23da78a48d6..d8506cbcdad1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,7 +10,6 @@
#define MLIR_IR_BUILTINATTRIBUTES_H
#include "mlir/IR/BuiltinAttributeInterfaces.h"
-#include "mlir/IR/SubElementInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include <complex>
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0e7b5012b082..446bc375d3fa 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -18,7 +18,6 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
-include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the attributes defined in this file are prefixed with
// `Builtin_`. This is to differentiate the attributes here with the ones in
@@ -71,9 +70,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
// ArrayAttr
//===----------------------------------------------------------------------===//
-def Builtin_ArrayAttr : Builtin_Attr<"Array", [
- SubElementAttrInterface
- ]> {
+def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
let summary = "A collection of other Attribute values";
let description = [{
Syntax:
@@ -491,9 +488,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
// DictionaryAttr
//===----------------------------------------------------------------------===//
-def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
- SubElementAttrInterface
- ]> {
+def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
let summary = "An dictionary of named Attribute values";
let description = [{
Syntax:
@@ -1096,9 +1091,7 @@ def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> {
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
- SubElementAttrInterface
- ]> {
+def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
let summary = "An Attribute containing a symbolic reference to an Operation";
let description = [{
Syntax:
@@ -1114,13 +1107,6 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
may optionally contain a set of nested references that further resolve to a
symbol nested within a different symbol table.
- This attribute can only be held internally by
- [array attributes](#array-attribute),
- [dictionary attributes](#dictionary-attribute)(including the top-level
- operation attribute dictionary) as well as attributes exposing it via
- the `SubElementAttrInterface` interface. Symbol reference attributes
- nested in types are currently not supported.
-
**Rationale:** Identifying accesses to global data is critical to
enabling efficient multi-threaded compilation. Restricting global
data access to occur through symbols and limiting the places that can
@@ -1171,9 +1157,7 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
// TypeAttr
//===----------------------------------------------------------------------===//
-def Builtin_TypeAttr : Builtin_Attr<"Type", [
- SubElementAttrInterface
- ]> {
+def Builtin_TypeAttr : Builtin_Attr<"Type"> {
let summary = "An Attribute containing a Type";
let description = [{
Syntax:
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 0395e1329590..3c9d2c57f4d1 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -15,7 +15,6 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
-include "mlir/IR/SubElementInterfaces.td"
// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name, list<Trait> traits = []>
@@ -28,9 +27,7 @@ class Builtin_LocationAttr<string name, list<Trait> traits = []>
// CallSiteLoc
//===----------------------------------------------------------------------===//
-def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [
- SubElementAttrInterface
- ]> {
+def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc"> {
let summary = "A callsite source location";
let description = [{
Syntax:
@@ -107,9 +104,7 @@ def FileLineColLoc : Builtin_LocationAttr<"FileLineColLoc"> {
// FusedLoc
//===----------------------------------------------------------------------===//
-def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
- SubElementAttrInterface
- ]> {
+def FusedLoc : Builtin_LocationAttr<"FusedLoc"> {
let summary = "A tuple of other source locations";
let description = [{
Syntax:
@@ -148,9 +143,7 @@ def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
// NameLoc
//===----------------------------------------------------------------------===//
-def NameLoc : Builtin_LocationAttr<"NameLoc", [
- SubElementAttrInterface
- ]> {
+def NameLoc : Builtin_LocationAttr<"NameLoc"> {
let summary = "A named source location";
let description = [{
Syntax:
@@ -187,9 +180,7 @@ def NameLoc : Builtin_LocationAttr<"NameLoc", [
// OpaqueLoc
//===----------------------------------------------------------------------===//
-def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [
- SubElementAttrInterface
- ]> {
+def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
let summary = "An opaque source location";
let description = [{
An instance of this location essentially contains a pointer to some data
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f06581a4b0d7..135fa9b559d8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -11,7 +11,6 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/SubElementInterfaces.h"
namespace llvm {
class BitVector;
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 8b7bfafab568..5f9141d71b69 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,7 +17,6 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -165,9 +164,7 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> {
// FunctionType
//===----------------------------------------------------------------------===//
-def Builtin_Function : Builtin_Type<"Function", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
- ]> {
+def Builtin_Function : Builtin_Type<"Function"> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:
@@ -314,7 +311,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+ ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
let description = [{
@@ -649,7 +646,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+ ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
@@ -753,9 +750,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
// TupleType
//===----------------------------------------------------------------------===//
-def Builtin_Tuple : Builtin_Type<"Tuple", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
- ]> {
+def Builtin_Tuple : Builtin_Type<"Tuple"> {
let summary = "Fixed-sized collection of other types";
let description = [{
Syntax:
@@ -823,7 +818,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+ ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
@@ -895,7 +890,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+ ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
@@ -943,9 +938,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_Vector : Builtin_Type<"Vector", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
- ], "Type"> {
+def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
let description = [{
Syntax:
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 92f55b557baf..78d41d6dc4ab 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -43,13 +43,6 @@ mlir_tablegen(FunctionOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRFunctionInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIRFunctionInterfacesIncGen)
-set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td)
-mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls)
-mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs)
-mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls)
-mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRSubElementInterfacesIncGen)
-
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index b772cf4b90e3..63b12899e249 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -15,7 +15,6 @@
#define MLIR_IR_LOCATION_H
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/SubElementInterfaces.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
namespace mlir {
@@ -172,13 +171,13 @@ inline OpaqueLoc OpaqueLoc::get(T underlyingLocation, MLIRContext *context) {
}
//===----------------------------------------------------------------------===//
-// SubElementInterfaces
+// SubElements
//===----------------------------------------------------------------------===//
/// Enable locations to be introspected as sub-elements.
template <>
struct AttrTypeSubElementHandler<Location> {
- static void walk(Location param, AttrTypeSubElementWalker &walker) {
+ static void walk(Location param, AttrTypeImmediateSubElementWalker &walker) {
walker.walk(param);
}
static Location replace(Location param, AttrSubElementReplacements &attrRepls,
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 128ad815556c..061bd67e1090 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -13,6 +13,7 @@
#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
+#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/Support/InterfaceSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StorageUniquer.h"
@@ -126,6 +127,51 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
};
}
+ /// Walk all of the immediately nested sub-attributes and sub-types. This
+ /// method does not recurse into sub elements.
+ ///
+ /// This default forwards to `detail::walkImmediateSubElementsImpl` on the
+ /// concrete class. The static getter below calls through `ConcreteT`, so a
+ /// concrete class that declares its own `walkImmediateSubElements` takes
+ /// precedence over this one.
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ ::mlir::detail::walkImmediateSubElementsImpl(
+ *static_cast<const ConcreteT *>(this), walkAttrsFn, walkTypesFn);
+ }
+
+ /// Replace the immediately nested sub-attributes and sub-types with those
+ /// provided. The order of the provided elements is derived from the order of
+ /// the elements returned by the callbacks of `walkImmediateSubElements`. The
+ /// element at index 0 would replace the very first attribute given by
+ /// `walkImmediateSubElements`. On success, the new instance with the values
+ /// replaced is returned. If replacement fails, nullptr is returned.
+ ///
+ /// Note that replacing the sub-elements of mutable types or attributes is
+ /// not currently supported. If an implementing type or attribute is
+ /// mutable, it should return `nullptr` if it has no mechanism for replacing
+ /// sub elements. As with the walk above, a `replaceImmediateSubElements`
+ /// declared on `ConcreteT` takes precedence over this default.
+ auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return ::mlir::detail::replaceImmediateSubElementsImpl(
+ *static_cast<const ConcreteT *>(this), replAttrs, replTypes);
+ }
+
+ /// Returns a function that walks immediate sub elements of a given instance
+ /// of the storage user. The returned lambda is captureless and generic over
+ /// the instance handle; it `cast`s the handle to `ConcreteT` before walking.
+ static auto getWalkImmediateSubElementsFn() {
+ return [](auto instance, function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) {
+ cast<ConcreteT>(instance).walkImmediateSubElements(walkAttrsFn,
+ walkTypesFn);
+ };
+ }
+
+ /// Returns a function that replaces immediate sub elements of a given
+ /// instance of the storage user. The returned lambda is captureless and
+ /// generic over the instance handle; it `cast`s the handle to `ConcreteT`.
+ static auto getReplaceImmediateSubElementsFn() {
+ return [](auto instance, ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) {
+ return cast<ConcreteT>(instance).replaceImmediateSubElements(replAttrs,
+ replTypes);
+ };
+ }
+
/// Attach the given models as implementations of the corresponding interfaces
/// for the concrete storage user class. The type must be registered with the
/// context, i.e. the dialect to which the type belongs must be loaded. The
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
deleted file mode 100644
index d7054b422181..000000000000
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ /dev/null
@@ -1,142 +0,0 @@
-//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains a set of interfaces that can be used to interface with
-// sub-elements, e.g. held attributes and types, of a composite attribute or
-// type.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_
-#define MLIR_IR_SUBELEMENTINTERFACES_TD_
-
-include "mlir/IR/OpBase.td"
-
-//===----------------------------------------------------------------------===//
-// SubElementInterfaceBase
-//===----------------------------------------------------------------------===//
-
-class SubElementInterfaceBase<string interfaceName, string attrOrType,
- string derivedValue> {
- string cppNamespace = "::mlir";
-
- list<InterfaceMethod> methods = [
- InterfaceMethod<
- /*desc=*/[{
- Walk all of the immediately nested sub-attributes and sub-types. This
- method does not recurse into sub elements.
- }], "void", "walkImmediateSubElements",
- (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
- "llvm::function_ref<void(mlir::Type)>":$walkTypesFn),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{
- ::mlir::detail::walkImmediateSubElementsImpl(
- }] # derivedValue # [{, walkAttrsFn, walkTypesFn);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Replace the immediately nested sub-attributes and sub-types with those provided.
- The order of the provided elements is derived from the order of the elements
- returned by the callbacks of `walkImmediateSubElements`. The element at index 0
- would replace the very first attribute given by `walkImmediateSubElements`.
- On success, the new instance with the values replaced is returned. If replacement
- fails, nullptr is returned.
-
- Note that replacing the sub-elements of mutable types or attributes is
- not currently supported by the interface. If an implementing type or
- attribute is mutable, it should return `nullptr` if it has no mechanism
- for replacing sub elements.
- }], attrOrType, "replaceImmediateSubElements",
- (ins "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
- "::llvm::ArrayRef<::mlir::Type>":$replTypes),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{
- return ::mlir::detail::replaceImmediateSubElementsImpl(
- }] # derivedValue # [{, replAttrs, replTypes);
- }]>,
- ];
-
- code extraClassDeclaration = [{
- /// Walk all of the held sub-attributes and sub-types.
- void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
- llvm::function_ref<void(mlir::Type)> walkTypesFn);
-
- /// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. Returns nullptr in the case of failure. See
- /// `AttrTypeReplacer` for information on the support replacement function types.
- template <typename... ReplacementFns>
- }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
- AttrTypeReplacer replacer;
- (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
- return replacer.replace(*this);
- }
- }];
- code extraTraitClassDeclaration = [{
- /// Walk all of the held sub-attributes and sub-types.
- void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
- llvm::function_ref<void(mlir::Type)> walkTypesFn) {
- }] # interfaceName # " interface(" # derivedValue # [{);
- interface.walkSubElements(walkAttrsFn, walkTypesFn);
- }
-
- /// Recursively replace all of the nested sub-attributes and sub-types using the
- /// provided map functions. Returns nullptr in the case of failure. See
- /// `AttrTypeReplacer` for information on the support replacement function types.
- template <typename... ReplacementFns>
- }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
- AttrTypeReplacer replacer;
- (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
- return replacer.replace(}] # derivedValue # [{);
- }
- }];
- code extraSharedClassDeclaration = [{
- /// Walk all of the held sub-attributes.
- void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
- walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
- }
- /// Walk all of the held sub-types.
- void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
- walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
- }
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// SubElementAttrInterface
-//===----------------------------------------------------------------------===//
-
-def SubElementAttrInterface
- : AttrInterface<"SubElementAttrInterface">,
- SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute",
- "$_attr"> {
- let description = [{
- An interface used to query and manipulate sub-elements, such as sub-types
- and sub-attributes of a composite attribute.
-
- To support the introspection of custom parameters that hold sub-elements,
- a specialization of the `AttrTypeSubElementHandler` class must be provided.
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// SubElementTypeInterface
-//===----------------------------------------------------------------------===//
-
-def SubElementTypeInterface
- : TypeInterface<"SubElementTypeInterface">,
- SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type",
- "$_type"> {
- let description = [{
- An interface used to query and manipulate sub-elements, such as sub-types
- and sub-attributes of a composite type.
-
- To support the introspection of custom parameters that hold sub-elements,
- a specialization of the `AttrTypeSubElementHandler` class must be provided.
- }];
-}
-
-#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index b8215a80cb6f..37a7c242ee84 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -166,13 +166,13 @@ inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
}
//===----------------------------------------------------------------------===//
-// SubElementInterfaces
+// SubElements
//===----------------------------------------------------------------------===//
/// Enable TypeRange to be introspected for sub-elements.
template <>
struct AttrTypeSubElementHandler<TypeRange> {
- static void walk(TypeRange param, AttrTypeSubElementWalker &walker) {
+ static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) {
walker.walkRange(param);
}
static TypeRange replace(TypeRange param,
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 2bb28fd261eb..2aa6b1a59e86 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -30,6 +30,10 @@ class MLIRContext;
class AbstractType {
public:
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+ /// Callbacks used to walk / replace the immediate sub-elements of a type of
+ /// this kind.
+ /// NOTE(review): unlike `HasTraitFn` (an owning `unique_function`), these
+ /// are non-owning `function_ref`s, yet AbstractType stores them as members
+ /// and `get<T>` builds them from temporaries returned by
+ /// `T::getWalkImmediateSubElementsFn()` / `getReplaceImmediateSubElementsFn()`.
+ /// The referenced lambdas are captureless, so calls look benign in practice,
+ /// but the stored reference outlives its callable; plain function pointers
+ /// would avoid this. Confirm before relying on non-captureless callables.
+ using WalkImmediateSubElementsFn = function_ref<void(
+ Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
+ using ReplaceImmediateSubElementsFn =
+ function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>;
/// Look up the specified abstract type in the MLIRContext and return a
/// reference to it.
@@ -40,17 +44,23 @@ class AbstractType {
template <typename T>
static AbstractType get(Dialect &dialect) {
return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
- T::getTypeID());
+ T::getWalkImmediateSubElementsFn(),
+ T::getReplaceImmediateSubElementsFn(), T::getTypeID());
}
/// This method is used by Dialect objects to register types with
/// custom TypeIDs.
/// The use of this method is in general discouraged in favor of
/// 'get<CustomType>(dialect)';
- static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- HasTraitFn &&hasTrait, TypeID typeID) {
+ static AbstractType
+ get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+ HasTraitFn &&hasTrait,
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+ TypeID typeID) {
return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
- typeID);
+ walkImmediateSubElementsFn,
+ replaceImmediateSubElementsFn, typeID);
}
/// Return the dialect this type was registered to.
@@ -78,14 +88,29 @@ class AbstractType {
/// Returns true if the type has a particular trait.
bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+ /// Walk the immediate sub-elements of the given type.
+ void walkImmediateSubElements(Type type,
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
+
+ /// Replace the immediate sub-elements of the given type.
+ Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const;
+
/// Return the unique identifier representing the concrete type class.
TypeID getTypeID() const { return typeID; }
private:
AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- HasTraitFn &&hasTrait, TypeID typeID)
+ HasTraitFn &&hasTrait,
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+ TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
+ hasTraitFn(std::move(hasTrait)),
+ walkImmediateSubElementsFn(walkImmediateSubElementsFn),
+ replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
+ typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -106,6 +131,12 @@ class AbstractType {
/// Function to check if the type has a particular trait.
HasTraitFn hasTraitFn;
+ /// Function to walk the immediate sub-elements of this type.
+ WalkImmediateSubElementsFn walkImmediateSubElementsFn;
+
+ /// Function to replace the immediate sub-elements of this type.
+ ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
+
/// The unique identifier of the derived Type class.
const TypeID typeID;
};
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 9d64a77742ef..cc7ecfb9d468 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -186,11 +186,52 @@ class Type {
}
/// Return the abstract type descriptor for this type.
- const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
+ const AbstractTy &getAbstractType() const { return impl->getAbstractType(); }
/// Return the Type implementation.
ImplType *getImpl() const { return impl; }
+ /// Walk all of the immediately nested sub-attributes and sub-types. This
+ /// method does not recurse into sub elements. The work is dispatched through
+ /// this type's AbstractType.
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ getAbstractType().walkImmediateSubElements(*this, walkAttrsFn, walkTypesFn);
+ }
+
+ /// Replace the immediately nested sub-attributes and sub-types with those
+ /// provided. The order of the provided elements is derived from the order of
+ /// the elements returned by the callbacks of `walkImmediateSubElements`. The
+ /// element at index 0 would replace the very first attribute given by
+ /// `walkImmediateSubElements`. On success, the new instance with the values
+ /// replaced is returned. If replacement fails, nullptr is returned. The work
+ /// is dispatched through this type's AbstractType.
+ auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return getAbstractType().replaceImmediateSubElements(*this, replAttrs,
+ replTypes);
+ }
+
+ /// Walk this type and all attributes/types nested within using the
+ /// provided walk functions. See `AttrTypeWalker` for information on the
+ /// supported walk function types. `Order` selects the walk order and
+ /// defaults to `WalkOrder::PostOrder`.
+ template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
+ auto walk(WalkFns &&...walkFns) {
+ AttrTypeWalker walker;
+ // Register every provided walk function with the walker (fold expression).
+ (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
+ return walker.walk<Order>(*this);
+ }
+
+ /// Recursively replace all of the nested sub-attributes and sub-types using
+ /// the provided replacement functions. Returns nullptr in the case of
+ /// failure. See `AttrTypeReplacer` for information on the supported
+ /// replacement function types.
+ template <typename... ReplacementFns>
+ auto replace(ReplacementFns &&...replacementFns) {
+ AttrTypeReplacer replacer;
+ // Register every provided replacement function (fold expression).
+ (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
+ ...);
+ return replacer.replace(*this);
+ }
+
protected:
ImplType *impl{nullptr};
};
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index b170c2655f70..51e79339b7bb 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -35,7 +35,7 @@ class WalkResult {
enum ResultEnum { Interrupt, Advance, Skip } result;
public:
- WalkResult(ResultEnum result) : result(result) {}
+ /// `result` defaults to `Advance`, so a default-constructed `WalkResult()`
+ /// carries the `Advance` result.
+ WalkResult(ResultEnum result = Advance) : result(result) {}
/// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)
diff --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
index 9579111bc618..1d292c5fa45e 100644
--- a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
@@ -80,20 +80,15 @@ IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
- typeConverter.addConversion([mapping](Type type) -> std::optional<Type> {
- auto subElementType = type.dyn_cast_or_null<SubElementTypeInterface>();
- if (!subElementType)
- return type;
- Type newType = subElementType.replaceSubElements(
- [mapping](Attribute attr) -> std::optional<Attribute> {
- auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
- if (!memorySpaceAttr)
- return std::nullopt;
- auto newValue = wrapNumericMemorySpace(
- attr.getContext(), mapping(memorySpaceAttr.getValue()));
- return newValue;
- });
- return newType;
+ // Every type is passed to Type::replace. Each gpu::AddressSpaceAttr nested
+ // in it is swapped for the numeric memory space that `mapping` produces,
+ // wrapped by wrapNumericMemorySpace.
+ // NOTE(review): any other attribute yields std::nullopt, which
+ // AttrTypeReplacer presumably treats as "leave unchanged"; confirm.
+ typeConverter.addConversion([mapping](Type type) {
+ return type.replace([mapping](Attribute attr) -> std::optional<Attribute> {
+ auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
+ if (!memorySpaceAttr)
+ return std::nullopt;
+ auto newValue = wrapNumericMemorySpace(
+ attr.getContext(), mapping(memorySpaceAttr.getValue()));
+ return newValue;
+ });
});
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0981b3b1cbe2..af5cb6f354a2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -25,7 +25,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
@@ -841,13 +840,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
}
// For most builtin types, we can simply walk the sub elements.
- if (auto subElementInterface = dyn_cast<SubElementTypeInterface>(type)) {
- auto visitFn = [&](auto element) {
- if (element)
- (void)printAlias(element);
- };
- subElementInterface.walkImmediateSubElements(visitFn, visitFn);
- }
+ auto visitFn = [&](auto element) {
+ if (element)
+ (void)printAlias(element);
+ };
+ type.walkImmediateSubElements(visitFn, visitFn);
}
/// Consider the given type to be printed for an alias.
diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
similarity index 53%
rename from mlir/lib/IR/SubElementInterfaces.cpp
rename to mlir/lib/IR/AttrTypeSubElements.cpp
index 528e0cadfa7c..79b04966be6e 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -1,4 +1,4 @@
-//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===//
+//===- AttrTypeSubElements.cpp - Attr and Type SubElement Utilities -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,96 +6,77 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Operation.h"
-
-#include "llvm/ADT/DenseSet.h"
#include <optional>
using namespace mlir;
//===----------------------------------------------------------------------===//
-// SubElementInterface
+// AttrTypeWalker
//===----------------------------------------------------------------------===//
-//===----------------------------------------------------------------------===//
-// WalkSubElements
-
-template <typename InterfaceT>
-static void walkSubElementsImpl(InterfaceT interface,
- function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn,
- DenseSet<Attribute> &visitedAttrs,
- DenseSet<Type> &visitedTypes) {
- interface.walkImmediateSubElements(
- [&](Attribute attr) {
- // Guard against potentially null inputs. This removes the need for the
- // derived attribute/type to do it.
- if (!attr)
- return;
-
- // Avoid infinite recursion when visiting sub attributes later, if this
- // is a mutable attribute.
- if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
- if (!visitedAttrs.insert(attr).second)
- return;
- }
-
- // Walk any sub elements first.
- if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
- visitedTypes);
+/// Walk `attr` (and, depending on `order`, its sub-elements) with the
+/// attribute callbacks held in `attrWalkFns`.
+WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
+ return walkImpl(attr, attrWalkFns, order);
+}
+/// Walk `type` (and, depending on `order`, its sub-elements) with the type
+/// callbacks held in `typeWalkFns`.
+WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
+ return walkImpl(type, typeWalkFns, order);
+}
- // Walk this attribute.
- walkAttrsFn(attr);
- },
- [&](Type type) {
- // Guard against potentially null inputs. This removes the need for the
- // derived attribute/type to do it.
- if (!type)
- return;
-
- // Avoid infinite recursion when visiting sub types later, if this
- // is a mutable type.
- if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
- if (!visitedTypes.insert(type).second)
- return;
- }
+/// Shared implementation of walking an attribute or type element.
+///
+/// Results are memoized in `visitedAttrTypes`, keyed on (element, order).
+/// Returns `interrupt` if any callback (on this element or a sub-element)
+/// interrupted the walk, and `advance` otherwise. A `skip` from a callback
+/// stops descent into this element's sub-elements and is reported as
+/// `advance`.
+template <typename T, typename WalkFns>
+WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
+ WalkOrder order) {
+ // Check if we've already walked this element before. A single try_emplace
+ // both probes the cache and seeds it with `advance` (one hash lookup rather
+ // than find + try_emplace). Seeding before recursing is what stops cyclic,
+ // mutable elements from being re-entered during this walk.
+ auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
+ auto [it, inserted] =
+ visitedAttrTypes.try_emplace(key, WalkResult::advance());
+ if (!inserted)
+ return it->second;
- // Walk any sub elements first.
- if (auto interface = type.dyn_cast<SubElementTypeInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
- visitedTypes);
+ // NOTE: `it` is deliberately not reused below. Recursive walks insert into
+ // visitedAttrTypes and may invalidate it, so results are stored by key.
+ // If we are walking in post order, walk the sub elements first.
+ if (order == WalkOrder::PostOrder) {
+ if (walkSubElements(element, order).wasInterrupted())
+ return visitedAttrTypes[key] = WalkResult::interrupt();
+ }
- // Walk this type.
- walkTypesFn(type);
- });
-}
+ // Walk this element, bailing if skipped or interrupted. A skip returns
+ // `advance` (the value already cached) without visiting sub elements.
+ for (auto &walkFn : llvm::reverse(walkFns)) {
+ WalkResult walkResult = walkFn(element);
+ if (walkResult.wasInterrupted())
+ return visitedAttrTypes[key] = WalkResult::interrupt();
+ if (walkResult.wasSkipped())
+ return WalkResult::advance();
+ }
-void SubElementAttrInterface::walkSubElements(
- function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn) {
- assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- DenseSet<Attribute> visitedAttrs;
- DenseSet<Type> visitedTypes;
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
- visitedTypes);
+ // If we are walking in pre-order, walk the sub elements last. Record an
+ // interruption in the cache exactly as the post-order path does. Otherwise
+ // a later cached lookup of this element would report `advance` for a walk
+ // that was actually interrupted.
+ if (order == WalkOrder::PreOrder) {
+ if (walkSubElements(element, order).wasInterrupted())
+ return visitedAttrTypes[key] = WalkResult::interrupt();
+ }
+ return WalkResult::advance();
}
-void SubElementTypeInterface::walkSubElements(
- function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn) {
- assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- DenseSet<Attribute> visitedAttrs;
- DenseSet<Type> visitedTypes;
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
- visitedTypes);
+/// Walk each non-null immediate sub-element (attribute or type) of
+/// `interface` via walkImpl, and return `interrupt` if any of those walks
+/// was interrupted, `advance` otherwise.
+template <typename T>
+WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
+ WalkResult result = WalkResult::advance();
+ // The walkImmediateSubElements callbacks return void and cannot stop the
+ // iteration, so once interrupted the remaining elements are just ignored.
+ auto walkFn = [&](auto element) {
+ if (element && !result.wasInterrupted())
+ result = walkImpl(element, order);
+ };
+ interface.walkImmediateSubElements(walkFn, walkFn);
+ // Only an interrupt is propagated; any other result is reported as advance.
+ return result.wasInterrupted() ? result : WalkResult::advance();
}
//===----------------------------------------------------------------------===//
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//
+/// Add a replacement callback that is consulted when replacing attributes.
+/// Both overloads append to their list the same way (emplace_back).
+void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+ attrReplacementFns.emplace_back(std::move(fn));
+}
+/// Add a replacement callback that is consulted when replacing types.
+void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+ typeReplacementFns.emplace_back(std::move(fn));
+}
+
void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
bool replaceLocs, bool replaceTypes) {
// Functor that replaces the given element if the new value is different,
@@ -157,7 +138,6 @@ void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
template <typename T>
static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
- DenseMap<T, T> &elementMap,
SmallVectorImpl<T> &newElements,
FailureOr<bool> &changed) {
// Bail early if we failed at any point.
@@ -180,19 +160,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
}
}
-template <typename InterfaceT, typename T>
-T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
- DenseMap<T, T> &interfaceMap) {
+template <typename T>
+T AttrTypeReplacer::replaceSubElements(T interface) {
// Walk the current sub-elements, replacing them as necessary.
SmallVector<Attribute, 16> newAttrs;
SmallVector<Type, 16> newTypes;
FailureOr<bool> changed = false;
interface.walkImmediateSubElements(
[&](Attribute element) {
- updateSubElementImpl(element, *this, attrMap, newAttrs, changed);
+ updateSubElementImpl(element, *this, newAttrs, changed);
},
[&](Type element) {
- updateSubElementImpl(element, *this, typeMap, newTypes, changed);
+ updateSubElementImpl(element, *this, newTypes, changed);
});
if (failed(changed))
return nullptr;
@@ -205,12 +184,12 @@ T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
}
/// Shared implementation of replacing a given attribute or type element.
-template <typename InterfaceT, typename ReplaceFns, typename T>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
- DenseMap<T, T> &map) {
- auto [it, inserted] = map.try_emplace(element, element);
+template <typename T, typename ReplaceFns>
+T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
+ const void *opaqueElement = element.getAsOpaquePointer();
+ auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
if (!inserted)
- return it->second;
+ return T::getFromOpaquePointer(it->second);
T result = element;
WalkResult walkResult = WalkResult::advance();
@@ -222,34 +201,42 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
}
// If an error occurred, return nullptr to indicate failure.
- if (walkResult.wasInterrupted() || !result)
- return map[element] = nullptr;
+ if (walkResult.wasInterrupted() || !result) {
+ attrTypeMap[opaqueElement] = nullptr;
+ return nullptr;
+ }
// Handle replacing sub-elements if this element is also a container.
if (!walkResult.wasSkipped()) {
- if (auto interface = dyn_cast<InterfaceT>(result)) {
- // Replace the sub elements of this element, bailing if we fail.
- if (!(result = replaceSubElements(interface, map)))
- return map[element] = nullptr;
+ // Replace the sub elements of this element, bailing if we fail.
+ if (!(result = replaceSubElements(result))) {
+ attrTypeMap[opaqueElement] = nullptr;
+ return nullptr;
}
}
- return map[element] = result;
+ attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+ return result;
}
+/// Replace `attr` (and, unless a callback skips, its sub-elements) using the
+/// attribute replacement callbacks. Returns nullptr if replacement failed.
Attribute AttrTypeReplacer::replace(Attribute attr) {
- return replaceImpl<SubElementAttrInterface>(attr, attrReplacementFns,
- attrMap);
+ return replaceImpl(attr, attrReplacementFns);
}
+/// Replace `type` (and, unless a callback skips, its sub-elements) using the
+/// type replacement callbacks. Returns nullptr if replacement failed.
Type AttrTypeReplacer::replace(Type type) {
- return replaceImpl<SubElementTypeInterface>(type, typeReplacementFns,
- typeMap);
+ return replaceImpl(type, typeReplacementFns);
}
//===----------------------------------------------------------------------===//
-// SubElementInterface Tablegen definitions
+// AttrTypeImmediateSubElementWalker
//===----------------------------------------------------------------------===//
-#include "mlir/IR/SubElementAttrInterfaces.cpp.inc"
-#include "mlir/IR/SubElementTypeInterfaces.cpp.inc"
+/// Forward a non-null attribute sub-element to `walkAttrsFn`; null elements
+/// are ignored so callers need not guard against them.
+void AttrTypeImmediateSubElementWalker::walk(Attribute element) {
+ if (element)
+ walkAttrsFn(element);
+}
+
+/// Forward a non-null type sub-element to `walkTypesFn`; null elements are
+/// ignored so callers need not guard against them.
+void AttrTypeImmediateSubElementWalker::walk(Type element) {
+ if (element)
+ walkTypesFn(element);
+}
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 4e585da7d909..2798944c2df3 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -12,6 +12,23 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+// AbstractAttribute
+//===----------------------------------------------------------------------===//
+
+/// Walk the immediate sub-elements of `attr` by forwarding to the
+/// `walkImmediateSubElementsFn` callback of this abstract attribute
+/// (presumably the one supplied to AbstractAttribute::get).
+void AbstractAttribute::walkImmediateSubElements(
+ Attribute attr, function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkImmediateSubElementsFn(attr, walkAttrsFn, walkTypesFn);
+}
+
+/// Build a version of `attr` whose immediate sub-elements are replaced by
+/// `replAttrs`/`replTypes`, by forwarding to the
+/// `replaceImmediateSubElementsFn` callback of this abstract attribute.
+Attribute
+AbstractAttribute::replaceImmediateSubElements(Attribute attr,
+ ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return replaceImmediateSubElementsFn(attr, replAttrs, replTypes);
+}
+
//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 355ddd4d450a..8b4fb42e03ea 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRIR
AffineMap.cpp
AsmPrinter.cpp
Attributes.cpp
+ AttrTypeSubElements.cpp
Block.cpp
Builders.cpp
BuiltinAttributeInterfaces.cpp
@@ -26,7 +27,6 @@ add_mlir_library(MLIRIR
PatternMatch.cpp
Region.cpp
RegionKindInterface.cpp
- SubElementInterfaces.cpp
SymbolTable.cpp
TensorEncoding.cpp
Types.cpp
@@ -54,7 +54,6 @@ add_mlir_library(MLIRIR
MLIROpAsmInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
- MLIRSubElementInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 701683fdcb5c..00849857b51d 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -407,9 +407,10 @@ void ExtensibleDialect::registerDynamicType(
assert(registered &&
"Trying to create a new dynamic type with an existing name");
- auto abstractType =
- AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
- DynamicType::getHasTraitFn(), typeID);
+ auto abstractType = AbstractType::get(
+ *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
+ DynamicType::getWalkImmediateSubElementsFn(),
+ DynamicType::getReplaceImmediateSubElementsFn(), typeID);
/// Add the type to the dialect and the type uniquer.
addType(typeID, std::move(abstractType));
@@ -436,9 +437,10 @@ void ExtensibleDialect::registerDynamicAttr(
assert(registered &&
"Trying to create a new dynamic attribute with an existing name");
- auto abstractAttr =
- AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
- DynamicAttr::getHasTraitFn(), typeID);
+ auto abstractAttr = AbstractAttribute::get(
+ *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
+ DynamicAttr::getWalkImmediateSubElementsFn(),
+ DynamicAttr::getReplaceImmediateSubElementsFn(), typeID);
/// Add the type to the dialect and the type uniquer.
addAttribute(typeID, std::move(abstractAttr));
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 4c3b3bb8be5a..446cce3db413 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -485,66 +485,14 @@ LogicalResult detail::verifySymbol(Operation *op) {
static WalkResult
walkSymbolRefs(Operation *op,
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
- // Check to see if the operation has any attributes.
- DictionaryAttr attrDict = op->getAttrDictionary();
- if (attrDict.empty())
- return WalkResult::advance();
-
- // A worklist of a container attribute and the current index into the held
- // attribute list.
- struct WorklistItem {
- SubElementAttrInterface container;
- SmallVector<Attribute> immediateSubElements;
-
- explicit WorklistItem(SubElementAttrInterface container) {
- SmallVector<Attribute> subElements;
- container.walkImmediateSubElements(
- [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
- immediateSubElements = std::move(subElements);
- }
- };
-
- SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
- SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
-
- // Process the symbol references within the given nested attribute range.
- auto processAttrs = [&](int &index,
- WorklistItem &worklistItem) -> WalkResult {
- for (Attribute attr :
- llvm::drop_begin(worklistItem.immediateSubElements, index)) {
- // Invoke the provided callback if we find a symbol use and check for a
- // requested interrupt.
- if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
+ // Walk the op's attribute dictionary in pre-order and invoke `callback` on
+ // every SymbolRefAttr found, stopping if the callback interrupts.
+ return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
+ [&](SymbolRefAttr symbolRef) {
if (callback({op, symbolRef}).wasInterrupted())
return WalkResult::interrupt();
- /// Check for a nested container attribute, these will also need to be
- /// walked.
- } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
- attrWorklist.emplace_back(interface);
- curAccessChain.push_back(-1);
- return WalkResult::advance();
- }
- // Make sure to keep the index counter in sync.
- ++index;
- }
-
- // Pop this container attribute from the worklist.
- attrWorklist.pop_back();
- curAccessChain.pop_back();
- return WalkResult::advance();
- };
-
- WalkResult result = WalkResult::advance();
- do {
- WorklistItem &item = attrWorklist.back();
- int &index = curAccessChain.back();
- ++index;
-
- // Process the given attribute, which is guaranteed to be a container.
- result = processAttrs(index, item);
- } while (!attrWorklist.empty() && !result.wasInterrupted());
- return result;
+ // Don't walk nested references.
+ // (In a pre-order walk, `skip` keeps the walker from descending into
+ // this SymbolRefAttr's own sub-elements.)
+ return WalkResult::skip();
+ });
}
/// Walk all of the uses, for any symbol, that are nested within the given
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 9dc8e6380c79..1d65fccb82b8 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -31,7 +31,7 @@ struct IntegerTypeStorage : public TypeStorage {
: width(width), signedness(signedness) {}
/// The hash key used for uniquing.
- using KeyTy = std::pair<unsigned, IntegerType::SignednessSemantics>;
+ using KeyTy = std::tuple<unsigned, IntegerType::SignednessSemantics>;
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
@@ -44,7 +44,7 @@ struct IntegerTypeStorage : public TypeStorage {
static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
KeyTy key) {
return new (allocator.allocate<IntegerTypeStorage>())
- IntegerTypeStorage(key.first, key.second);
+ IntegerTypeStorage(std::get<0>(key), std::get<1>(key));
}
KeyTy getAsKey() const { return KeyTy(width, signedness); }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 670974bbf837..070ed4b14686 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -12,6 +12,22 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+// AbstractType
+//===----------------------------------------------------------------------===//
+
+/// Walk the immediate sub-elements of `type` by forwarding to the
+/// `walkImmediateSubElementsFn` callback of this abstract type
+/// (presumably the one supplied to AbstractType::get).
+void AbstractType::walkImmediateSubElements(
+ Type type, function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkImmediateSubElementsFn(type, walkAttrsFn, walkTypesFn);
+}
+
+/// Build a version of `type` whose immediate sub-elements are replaced by
+/// `replAttrs`/`replTypes`, by forwarding to the
+/// `replaceImmediateSubElementsFn` callback of this abstract type.
+Type AbstractType::replaceImmediateSubElements(Type type,
+ ArrayRef<Attribute> replAttrs,
+ ArrayRef<Type> replTypes) const {
+ return replaceImmediateSubElementsFn(type, replAttrs, replTypes);
+}
+
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index c7d48b6c4eb1..fe1c80e8f400 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -76,7 +76,7 @@ module {
// -----
-// Check that replacement works in any implementations of SubElementsAttrInterface
+// Check that replacement works in any implementations of SubElements.
module {
// CHECK: func private @replaced_foo
func.func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index f28afe4daf05..bd7c050f9e85 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -21,7 +21,6 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
-include "mlir/IR/SubElementInterfaces.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@@ -120,9 +119,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
let hasCustomAssemblyFormat = 1;
}
-def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
- SubElementAttrInterface
- ]> {
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess"> {
let mnemonic = "sub_elements_access";
let parameters = (ins
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index b6a7e275781a..c6c8f58ea552 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -22,7 +22,6 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -132,7 +131,6 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
TestRecursiveTypeStorage,
- ::mlir::SubElementTypeInterface::Trait,
::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index 295c58f925e8..f39093498e09 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -8,7 +8,6 @@
#include "LLVMTestBase.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/SubElementInterfaces.h"
using namespace mlir;
using namespace mlir::LLVM;
@@ -31,33 +30,24 @@ TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
Type barBody[] = {LLVMPointerType::get(fooStructTy)};
ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*isPacked=*/false)));
- auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
- ASSERT_TRUE(bool(subElementInterface));
// Test if walkSubElements goes into infinite loops.
SmallVector<Type, 4> subElementTypes;
- subElementInterface.walkSubElements(
- [](Attribute attr) {},
- [&](Type type) { subElementTypes.push_back(type); });
- // We don't record LLVMPointerType (because it's immutable), thus
- // !llvm.ptr<struct<"bar",...>> will be visited twice.
- ASSERT_EQ(subElementTypes.size(), 5U);
+ fooStructTy.walk([&](Type type) { subElementTypes.push_back(type); });
+ ASSERT_EQ(subElementTypes.size(), 4U);
- // !llvm.ptr<struct<"bar",...>>
+ // !llvm.ptr<struct<"foo",...>>
ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
- // !llvm.struct<"foo",...>
+ // !llvm.struct<"bar",...>
auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
ASSERT_TRUE(bool(structType));
- ASSERT_TRUE(structType.getName().equals("foo"));
+ ASSERT_TRUE(structType.getName().equals("bar"));
- // !llvm.ptr<struct<"foo",...>>
+ // !llvm.ptr<struct<"bar",...>>
ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
- // !llvm.struct<"bar",...>
+ // !llvm.struct<"foo",...>
structType = subElementTypes[3].dyn_cast<LLVMStructType>();
ASSERT_TRUE(bool(structType));
- ASSERT_TRUE(structType.getName().equals("bar"));
-
- // !llvm.ptr<struct<"bar",...>>
- ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
+ ASSERT_TRUE(structType.getName().equals("foo"));
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 13c7762563b1..7c0572ee89fe 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -422,4 +422,25 @@ TEST(SparseElementsAttrTest, GetZero) {
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
}
+//===----------------------------------------------------------------------===//
+// SubElements
+//===----------------------------------------------------------------------===//
+
+/// Check that Attribute::walk visits every attribute nested in a dictionary,
+/// including the root dictionary itself.
+TEST(SubElementTest, Nested) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ BoolAttr trueAttr = builder.getBoolAttr(true);
+ BoolAttr falseAttr = builder.getBoolAttr(false);
+ ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
+ StringAttr strAttr = builder.getStringAttr("array");
+ DictionaryAttr dictAttr =
+ builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
+
+ SmallVector<Attribute> subAttrs;
+ dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); });
+ // The expected order is post-order: each attribute comes after its
+ // children, and the root dictionary is visited last.
+ EXPECT_EQ(llvm::ArrayRef(subAttrs),
+ ArrayRef<Attribute>(
+ {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
+}
} // namespace
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 9dcfb71744be..7d49283c59c3 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_unittest(MLIRIRTests
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
- SubElementInterfaceTest.cpp
TypeTest.cpp
DEPENDS
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index e77e8794d696..9b20d3c1219e 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -41,24 +41,6 @@ TEST(InterfaceTest, OpInterfaceDenseMapKey) {
EXPECT_FALSE(opSet.contains(op3));
}
-TEST(InterfaceTest, AttrInterfaceDenseMapKey) {
- MLIRContext context;
- context.loadDialect<test::TestDialect>();
-
- OpBuilder builder(&context);
-
- DenseSet<SubElementAttrInterface> attrSet;
- auto attr1 = builder.getArrayAttr({});
- auto attr2 = builder.getI32ArrayAttr({0});
- auto attr3 = builder.getI32ArrayAttr({1});
- attrSet.insert(attr1);
- attrSet.insert(attr2);
- attrSet.erase(attr1);
- EXPECT_FALSE(attrSet.contains(attr1));
- EXPECT_TRUE(attrSet.contains(attr2));
- EXPECT_FALSE(attrSet.contains(attr3));
-}
-
TEST(InterfaceTest, TypeInterfaceDenseMapKey) {
MLIRContext context;
context.loadDialect<test::TestDialect>();
diff --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp
deleted file mode 100644
index ab461f4dc340..000000000000
--- a/mlir/unittests/IR/SubElementInterfaceTest.cpp
+++ /dev/null
@@ -1,36 +0,0 @@
-//===- SubElementInterfaceTest.cpp - SubElementInterface unit tests -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/SubElementInterfaces.h"
-#include "gtest/gtest.h"
-#include <cstdint>
-
-using namespace mlir;
-using namespace mlir::detail;
-
-namespace {
-TEST(SubElementInterfaceTest, Nested) {
- MLIRContext context;
- Builder builder(&context);
-
- BoolAttr trueAttr = builder.getBoolAttr(true);
- BoolAttr falseAttr = builder.getBoolAttr(false);
- ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
- StringAttr strAttr = builder.getStringAttr("array");
- DictionaryAttr dictAttr =
- builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
-
- SmallVector<Attribute> subAttrs;
- dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); });
- EXPECT_EQ(llvm::ArrayRef(subAttrs),
- ArrayRef<Attribute>({strAttr, trueAttr, falseAttr, boolArrayAttr}));
-}
-
-} // namespace
More information about the Mlir-commits
mailing list