[Mlir-commits] [mlir] 03d136c - [mlir] Promote the SubElementInterfaces to a core Attribute/Type construct

River Riddle llvmlistbot at llvm.org
Fri Jan 27 15:28:23 PST 2023


Author: River Riddle
Date: 2023-01-27T15:28:03-08:00
New Revision: 03d136cf5f3f10b618b7e17f897ebf6019518dcc

URL: https://github.com/llvm/llvm-project/commit/03d136cf5f3f10b618b7e17f897ebf6019518dcc
DIFF: https://github.com/llvm/llvm-project/commit/03d136cf5f3f10b618b7e17f897ebf6019518dcc.diff

LOG: [mlir] Promote the SubElementInterfaces to a core Attribute/Type construct

This commit restructures the sub element infrastructure to be a core part
of attributes and types, instead of being relegated to an interface. This
establishes sub element walking/replacement as something "always there",
which makes it easier to rely on for correctness/etc (which various bits of
infrastructure want, such as Symbols).

Attribute/Type now have `walk` and `replace` methods directly
accessible, which provide a powerful API for interacting with sub elements. As
part of this, a new AttrTypeWalker class is introduced that supports caching
walked attributes/types, and a friendlier API (see the simplification of symbol
walking in SymbolTable.cpp).

Differential Revision: https://reviews.llvm.org/D142272

Added: 
    mlir/include/mlir/IR/AttrTypeSubElements.h
    mlir/lib/IR/AttrTypeSubElements.cpp

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinLocationAttributes.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/include/mlir/IR/Location.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/TypeRange.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/IR/ExtensibleDialect.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/IR/TypeDetail.h
    mlir/lib/IR/Types.cpp
    mlir/test/IR/test-symbol-rauw.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
    mlir/unittests/IR/AttributeTest.cpp
    mlir/unittests/IR/CMakeLists.txt
    mlir/unittests/IR/InterfaceTest.cpp

Removed: 
    mlir/include/mlir/IR/SubElementInterfaces.h
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/unittests/IR/SubElementInterfaceTest.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 96936c4dcdc4..957f9eef908b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -12,7 +12,6 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Dialect/LLVMIR/LLVMEnums.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class LLVM_Attr<string name, string attrMnemonic,
@@ -160,9 +159,8 @@ def LLVM_DIBasicTypeAttr : LLVM_Attr<"DIBasicType", "di_basic_type",
 // DICompileUnitAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
-    SubElementAttrInterface
-  ], "DIScopeAttr"> {
+def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
+                                       /*traits=*/[], "DIScopeAttr"> {
   let parameters = (ins
     LLVM_DILanguageParameter:$sourceLanguage,
     "DIFileAttr":$file,
@@ -177,9 +175,8 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
 // DICompositeTypeAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", [
-    SubElementAttrInterface
-  ], "DITypeAttr"> {
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+                                         /*traits=*/[], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DITagParameter:$tag,
     "StringAttr":$name,
@@ -199,9 +196,8 @@ def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
 // DIDerivedTypeAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", [
-    SubElementAttrInterface
-  ], "DITypeAttr"> {
+def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type",
+                                       /*traits=*/[], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DITagParameter:$tag,
     OptionalParameter<"StringAttr">:$name,
@@ -231,9 +227,8 @@ def LLVM_DIFileAttr : LLVM_Attr<"DIFile", "di_file", /*traits=*/[], "DIScopeAttr
 // DILexicalBlockAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
-    SubElementAttrInterface
-  ], "DIScopeAttr"> {
+def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block",
+                                        /*traits=*/[], "DIScopeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
     OptionalParameter<"DIFileAttr">:$file,
@@ -255,9 +250,8 @@ def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
 // DILexicalBlockFileAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", [
-    SubElementAttrInterface
-  ], "DIScopeAttr"> {
+def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file",
+                                        /*traits=*/[], "DIScopeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
     OptionalParameter<"DIFileAttr">:$file,
@@ -277,9 +271,8 @@ def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_
 // DILocalVariableAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", [
-    SubElementAttrInterface
-  ], "DINodeAttr"> {
+def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
+                                         /*traits=*/[], "DINodeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
     "StringAttr":$name,
@@ -307,9 +300,8 @@ def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
 // DISubprogramAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", [
-    SubElementAttrInterface
-  ], "DIScopeAttr"> {
+def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
+                                      /*traits=*/[], "DIScopeAttr"> {
   let parameters = (ins
     "DICompileUnitAttr":$compileUnit,
     "DIScopeAttr":$scope,
@@ -357,9 +349,8 @@ def LLVM_DISubrangeAttr : LLVM_Attr<"DISubrange", "di_subrange", /*traits=*/[],
 // DISubroutineTypeAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", [
-    SubElementAttrInterface
-  ], "DITypeAttr"> {
+def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type",
+                                          /*traits=*/[], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DICallingConventionParameter:$callingConvention,
     OptionalArrayRefParameter<"DITypeAttr">:$types
@@ -377,9 +368,7 @@ def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_typ
 // MemoryEffectsAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects", [
-    SubElementAttrInterface
-  ]> {
+def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> {
   let parameters = (ins
     "ModRefInfo":$other,
     "ModRefInfo":$argMem,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index fcc35bdb856f..a44448e18b45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 
-#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include <optional>
@@ -104,7 +103,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 class LLVMStructType
     : public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
                             DataLayoutTypeInterface::Trait,
-                            SubElementTypeInterface::Trait,
                             TypeTrait::IsMutable> {
 public:
   /// Inherit base constructors.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 460e7ded84c3..caf4b58b87f5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,7 +11,6 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/SubElementInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 
 /// Base class for all LLVM dialect types.
@@ -25,8 +24,7 @@ class LLVMType<string typeName, string typeMnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def LLVMArrayType : LLVMType<"LLVMArray", "array", [
-    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>,
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>]> {
   let summary = "LLVM array type";
   let description = [{
     The `!llvm.array` type represents a fixed-size array of element types.
@@ -62,8 +60,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
 // LLVMFunctionType
 //===----------------------------------------------------------------------===//
 
-def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> {
   let summary = "LLVM function type";
   let description = [{
     The `!llvm.func` is a function type. It consists of a single return type
@@ -124,8 +121,7 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
 
 def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
-      "areCompatible", "verifyEntries"]>,
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+      "areCompatible", "verifyEntries"]>]> {
   let summary = "LLVM pointer type";
   let description = [{
     The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
@@ -171,8 +167,7 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
 // LLVMFixedVectorType
 //===----------------------------------------------------------------------===//
 
-def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
   let summary = "LLVM fixed vector type";
   let description = [{
     LLVM dialect scalable vector type, represents a sequence of elements of
@@ -202,8 +197,7 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
 // LLVMScalableVectorType
 //===----------------------------------------------------------------------===//
 
-def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
   let summary = "LLVM scalable vector type";
   let description = [{
     LLVM dialect scalable vector type, represents a sequence of elements of

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
similarity index 66%
rename from mlir/include/mlir/IR/SubElementInterfaces.h
rename to mlir/include/mlir/IR/AttrTypeSubElements.h
index 935d7fcd59cf..23b8d533b7cd 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -1,4 +1,4 @@
-//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===//
+//===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,112 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file contains interfaces and utilities for querying the sub elements of
-// an attribute or type.
+// This file contains utilities for querying the sub elements of an attribute or
+// type.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_IR_SUBELEMENTINTERFACES_H
-#define MLIR_IR_SUBELEMENTINTERFACES_H
+#ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H
+#define MLIR_IR_ATTRTYPESUBELEMENTS_H
 
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Visitors.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
 #include <optional>
 
 namespace mlir {
+class Attribute;
+class Type;
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeWalker
+//===----------------------------------------------------------------------===//
+
+/// This class provides a utility for walking attributes/types, and their sub
+/// elements. Multiple walk functions may be registered.
+class AttrTypeWalker {
+public:
+  //===--------------------------------------------------------------------===//
+  // Application
+  //===--------------------------------------------------------------------===//
+
+  /// Walk the given attribute/type, and recursively walk any sub elements.
+  template <WalkOrder Order, typename T>
+  WalkResult walk(T element) {
+    return walkImpl(element, Order);
+  }
+  template <typename T>
+  WalkResult walk(T element) {
+    return walk<WalkOrder::PostOrder, T>(element);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Registration
+  //===--------------------------------------------------------------------===//
+
+  template <typename T>
+  using WalkFn = std::function<WalkResult(T)>;
+
+  /// Register a walk function for a given attribute or type. A walk function
+  /// must be convertible to any of the following forms (where `T` is a class
+  /// derived from `Type` or `Attribute`):
+  ///
+  ///   * WalkResult(T)
+  ///     - Returns a walk result, which can be used to control the walk
+  ///
+  ///   * void(T)
+  ///     - Returns void, i.e. the walk always continues.
+  ///
+  /// Note: When walking, the most recently added walk functions will be
+  ///       invoked first.
+  void addWalk(WalkFn<Attribute> &&fn) {
+    attrWalkFns.emplace_back(std::move(fn));
+  }
+  void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(std::move(fn)); }
+
+  /// Register a walk function that doesn't match the default signature,
+  /// either because it uses a derived parameter type, or it uses a simplified
+  /// result type.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<
+                std::decay_t<FnT>>::template arg_t<0>,
+            typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+                                                Attribute, Type>,
+            typename ResultT = std::invoke_result_t<FnT, T>>
+  std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>>
+  addWalk(FnT &&callback) {
+    addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult {
+      if (auto derived = dyn_cast<T>(base)) {
+        if constexpr (std::is_convertible_v<ResultT, WalkResult>)
+          return callback(derived);
+        else
+          callback(derived);
+      }
+      return WalkResult::advance();
+    });
+  }
+
+private:
+  WalkResult walkImpl(Attribute attr, WalkOrder order);
+  WalkResult walkImpl(Type type, WalkOrder order);
+
+  /// Internal implementation of the `walk` methods above.
+  template <typename T, typename WalkFns>
+  WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order);
+
+  /// Walk the sub elements of the given interface.
+  template <typename T>
+  WalkResult walkSubElements(T interface, WalkOrder order);
+
+  /// The set of walk functions that map sub elements.
+  std::vector<WalkFn<Attribute>> attrWalkFns;
+  std::vector<WalkFn<Type>> typeWalkFns;
+
+  /// The set of visited attributes/types.
+  DenseMap<std::pair<const void *, int>, WalkResult> visitedAttrTypes;
+};
+
 //===----------------------------------------------------------------------===//
 /// AttrTypeReplacer
 //===----------------------------------------------------------------------===//
@@ -84,12 +176,8 @@ class AttrTypeReplacer {
   ///
   /// Note: When replacing, the mostly recently added replacement functions will
   ///       be invoked first.
-  void addReplacement(ReplaceFn<Attribute> fn) {
-    attrReplacementFns.emplace_back(std::move(fn));
-  }
-  void addReplacement(ReplaceFn<Type> fn) {
-    typeReplacementFns.push_back(std::move(fn));
-  }
+  void addReplacement(ReplaceFn<Attribute> fn);
+  void addReplacement(ReplaceFn<Type> fn);
 
   /// Register a replacement function that doesn't match the default signature,
   /// either because it uses a derived parameter type, or it uses a simplified
@@ -120,20 +208,19 @@ class AttrTypeReplacer {
 
 private:
   /// Internal implementation of the `replace` methods above.
-  template <typename InterfaceT, typename ReplaceFns, typename T>
-  T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap<T, T> &map);
+  template <typename T, typename ReplaceFns>
+  T replaceImpl(T element, ReplaceFns &replaceFns);
 
   /// Replace the sub elements of the given interface.
-  template <typename InterfaceT, typename T = typename InterfaceT::ValueType>
-  T replaceSubElements(InterfaceT interface, DenseMap<T, T> &interfaceMap);
+  template <typename T>
+  T replaceSubElements(T interface);
 
   /// The set of replacement functions that map sub elements.
   std::vector<ReplaceFn<Attribute>> attrReplacementFns;
   std::vector<ReplaceFn<Type>> typeReplacementFns;
 
   /// The set of cached mappings for attributes/types.
-  DenseMap<Attribute, Attribute> attrMap;
-  DenseMap<Type, Type> typeMap;
+  DenseMap<const void *, const void *> attrTypeMap;
 };
 
 //===----------------------------------------------------------------------===//
@@ -142,22 +229,16 @@ class AttrTypeReplacer {
 
 /// This class is used by AttrTypeSubElementHandler instances to walking sub
 /// attributes and types.
-class AttrTypeSubElementWalker {
+class AttrTypeImmediateSubElementWalker {
 public:
-  AttrTypeSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
-                           function_ref<void(Type)> walkTypesFn)
+  AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
+                                    function_ref<void(Type)> walkTypesFn)
       : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
 
   /// Walk an attribute.
-  void walk(Attribute element) {
-    if (element)
-      walkAttrsFn(element);
-  }
+  void walk(Attribute element);
   /// Walk a type.
-  void walk(Type element) {
-    if (element)
-      walkTypesFn(element);
-  }
+  void walk(Type element);
   /// Walk a range of attributes or types.
   template <typename RangeT>
   void walkRange(RangeT &&elements) {
@@ -212,7 +293,8 @@ using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>;
 template <typename T, typename Enable = void>
 struct AttrTypeSubElementHandler {
   /// Default walk implementation that does nothing.
-  static inline void walk(const T &param, AttrTypeSubElementWalker &walker) {}
+  static inline void walk(const T &param,
+                          AttrTypeImmediateSubElementWalker &walker) {}
 
   /// Default replace implementation just forwards the parameter.
   template <typename ParamT>
@@ -241,7 +323,7 @@ template <typename T>
 struct AttrTypeSubElementHandler<
     T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
                         std::is_base_of_v<Type, T>>> {
-  static void walk(T param, AttrTypeSubElementWalker &walker) {
+  static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
     walker.walk(param);
   }
   static T replace(T param, AttrSubElementReplacements &attrRepls,
@@ -255,27 +337,14 @@ struct AttrTypeSubElementHandler<
     }
   }
 };
-template <>
-struct AttrTypeSubElementHandler<NamedAttribute> {
-  template <typename T>
-  static void walk(T param, AttrTypeSubElementWalker &walker) {
-    walker.walk(param.getName());
-    walker.walk(param.getValue());
-  }
-  template <typename T>
-  static T replace(T param, AttrSubElementReplacements &attrRepls,
-                   TypeSubElementReplacements &typeRepls) {
-    ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
-    return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
-  }
-};
 /// Implementation for derived ArrayRef.
 template <typename T>
 struct AttrTypeSubElementHandler<ArrayRef<T>,
                                  std::enable_if_t<has_sub_attr_or_type_v<T>>> {
   using EltHandler = AttrTypeSubElementHandler<T>;
 
-  static void walk(ArrayRef<T> param, AttrTypeSubElementWalker &walker) {
+  static void walk(ArrayRef<T> param,
+                   AttrTypeImmediateSubElementWalker &walker) {
     for (const T &subElement : param)
       EltHandler::walk(subElement, walker);
   }
@@ -283,11 +352,11 @@ struct AttrTypeSubElementHandler<ArrayRef<T>,
                       TypeSubElementReplacements &typeRepls) {
     // Normal attributes/types can extract using the replacer directly.
     if constexpr (std::is_base_of_v<Attribute, T> &&
-                  sizeof(T) == sizeof(Attribute)) {
+                  sizeof(T) == sizeof(void *)) {
       ArrayRef<Attribute> attrs = attrRepls.take_front(param.size());
       return ArrayRef<T>((const T *)attrs.data(), attrs.size());
     } else if constexpr (std::is_base_of_v<Type, T> &&
-                         sizeof(T) == sizeof(Type)) {
+                         sizeof(T) == sizeof(void *)) {
       ArrayRef<Type> types = typeRepls.take_front(param.size());
       return ArrayRef<T>((const T *)types.data(), types.size());
     } else {
@@ -305,7 +374,7 @@ template <typename... Ts>
 struct AttrTypeSubElementHandler<
     std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
   static void walk(const std::tuple<Ts...> &param,
-                   AttrTypeSubElementWalker &walker) {
+                   AttrTypeImmediateSubElementWalker &walker) {
     std::apply(
         [&](const Ts &...params) {
           (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...);
@@ -333,6 +402,8 @@ template <typename... Ts>
 struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
 template <typename T, typename... Ts>
 using has_get_method = decltype(T::get(std::declval<Ts>()...));
+template <typename T, typename... Ts>
+using has_get_as_key = decltype(std::declval<T>().getAsKey());
 
 /// This function provides the underlying implementation for the
 /// SubElementInterface walk method, using the key type of the derived
@@ -341,21 +412,24 @@ template <typename T>
 void walkImmediateSubElementsImpl(T derived,
                                   function_ref<void(Attribute)> walkAttrsFn,
                                   function_ref<void(Type)> walkTypesFn) {
-  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+  using ImplT = typename T::ImplType;
+  if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
+    auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
 
-  // If we don't have any sub-elements, there is nothing to do.
-  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
-    return;
-  } else {
-    AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn);
-    AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
+    // If we don't have any sub-elements, there is nothing to do.
+    if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+      return;
+    } else {
+      AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn);
+      AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
+    }
   }
 }
 
 /// This function invokes the proper `get` method for  a type `T` with the given
 /// values.
 template <typename T, typename... Ts>
-T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
+auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
   // Prefer a direct `get` method if one exists.
   if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
     (void)ctx;
@@ -373,38 +447,39 @@ T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
 /// SubElementInterface replace method, using the key type of the derived
 /// attribute/type to interact with the individual parameters.
 template <typename T>
-T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
-                                  ArrayRef<Type> &replTypes) {
-  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
+                                     ArrayRef<Type> &replTypes) {
+  using ImplT = typename T::ImplType;
+  if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
+    auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
 
-  // If we don't have any sub-elements, we can just return the original.
-  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
-    return derived;
+    // If we don't have any sub-elements, we can just return the original.
+    if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+      return derived;
 
-    // Otherwise, we need to replace any necessary sub-elements.
-  } else {
-    AttrSubElementReplacements attrRepls(replAttrs);
-    TypeSubElementReplacements typeRepls(replTypes);
-    auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
-        key, attrRepls, typeRepls);
-    if constexpr (is_tuple<decltype(key)>::value) {
-      return std::apply(
-          [&](auto &&...params) {
-            return constructSubElementReplacement<T>(
-                derived.getContext(),
-                std::forward<decltype(params)>(params)...);
-          },
-          newKey);
+      // Otherwise, we need to replace any necessary sub-elements.
     } else {
-      return constructSubElementReplacement<T>(derived.getContext(), newKey);
+      AttrSubElementReplacements attrRepls(replAttrs);
+      TypeSubElementReplacements typeRepls(replTypes);
+      auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
+          key, attrRepls, typeRepls);
+      if constexpr (is_tuple<decltype(key)>::value) {
+        return std::apply(
+            [&](auto &&...params) {
+              return constructSubElementReplacement<T>(
+                  derived.getContext(),
+                  std::forward<decltype(params)>(params)...);
+            },
+            newKey);
+      } else {
+        return constructSubElementReplacement<T>(derived.getContext(), newKey);
+      }
     }
+  } else {
+    return derived;
   }
 }
 } // namespace detail
 } // namespace mlir
 
-/// Include the definitions of the sub element interfaces.
-#include "mlir/IR/SubElementAttrInterfaces.h.inc"
-#include "mlir/IR/SubElementTypeInterfaces.h.inc"
-
-#endif // MLIR_IR_SUBELEMENTINTERFACES_H
+#endif // MLIR_IR_ATTRTYPESUBELEMENTS_H

diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 73ede2e4d481..796adbe8531a 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -20,9 +20,6 @@
 #include "llvm/ADT/Twine.h"
 
 namespace mlir {
-class MLIRContext;
-class Type;
-
 //===----------------------------------------------------------------------===//
 // AbstractAttribute
 //===----------------------------------------------------------------------===//
@@ -32,6 +29,10 @@ class Type;
 class AbstractAttribute {
 public:
   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+  using WalkImmediateSubElementsFn = function_ref<void(
+      Attribute, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
+  using ReplaceImmediateSubElementsFn =
+      function_ref<Attribute(Attribute, ArrayRef<Attribute>, ArrayRef<Type>)>;
 
   /// Look up the specified abstract attribute in the MLIRContext and return a
   /// reference to it.
@@ -42,6 +43,8 @@ class AbstractAttribute {
   template <typename T>
   static AbstractAttribute get(Dialect &dialect) {
     return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+                             T::getWalkImmediateSubElementsFn(),
+                             T::getReplaceImmediateSubElementsFn(),
                              T::getTypeID());
   }
 
@@ -49,11 +52,15 @@ class AbstractAttribute {
   /// custom TypeIDs.
   /// The use of this method is in general discouraged in favor of
   /// 'get<CustomAttribute>(dialect)'.
-  static AbstractAttribute get(Dialect &dialect,
-                               detail::InterfaceMap &&interfaceMap,
-                               HasTraitFn &&hasTrait, TypeID typeID) {
+  static AbstractAttribute
+  get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+      HasTraitFn &&hasTrait,
+      WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+      ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+      TypeID typeID) {
     return AbstractAttribute(dialect, std::move(interfaceMap),
-                             std::move(hasTrait), typeID);
+                             std::move(hasTrait), walkImmediateSubElementsFn,
+                             replaceImmediateSubElementsFn, typeID);
   }
 
   /// Return the dialect this attribute was registered to.
@@ -82,14 +89,30 @@ class AbstractAttribute {
   /// Returns true if the attribute has a particular trait.
   bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
 
+  /// Walk the immediate sub-elements of this attribute.
+  void walkImmediateSubElements(Attribute attr,
+                                function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
+
+  /// Replace the immediate sub-elements of this attribute.
+  Attribute replaceImmediateSubElements(Attribute attr,
+                                        ArrayRef<Attribute> replAttrs,
+                                        ArrayRef<Type> replTypes) const;
+
   /// Return the unique identifier representing the concrete attribute class.
   TypeID getTypeID() const { return typeID; }
 
 private:
   AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-                    HasTraitFn &&hasTrait, TypeID typeID)
+                    HasTraitFn &&hasTraitFn,
+                    WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+                    ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+                    TypeID typeID)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
-        hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
+        hasTraitFn(std::move(hasTraitFn)),
+        walkImmediateSubElementsFn(walkImmediateSubElementsFn),
+        replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
+        typeID(typeID) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -110,6 +133,12 @@ class AbstractAttribute {
   /// Function to check if the attribute has a particular trait.
   HasTraitFn hasTraitFn;
 
+  /// Function to walk the immediate sub-elements of this attribute.
+  WalkImmediateSubElementsFn walkImmediateSubElementsFn;
+
+  /// Function to replace the immediate sub-elements of this attribute.
+  ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
+
   /// The unique identifier of the derived Attribute class.
   const TypeID typeID;
 };

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index a31560222ecf..a39ddf59df0a 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -101,6 +101,48 @@ class Attribute {
     return impl->getAbstractAttribute();
   }
 
+  /// Walk all of the immediately nested sub-attributes and sub-types. This
+  /// method does not recurse into sub elements.
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const {
+    getAbstractAttribute().walkImmediateSubElements(*this, walkAttrsFn,
+                                                    walkTypesFn);
+  }
+
+  /// Replace the immediately nested sub-attributes and sub-types with those
+  /// provided. The order of the provided elements is derived from the order of
+  /// the elements returned by the callbacks of `walkImmediateSubElements`. The
+  /// element at index 0 would replace the very first attribute given by
+  /// `walkImmediateSubElements`. On success, the new instance with the values
+  /// replaced is returned. If replacement fails, nullptr is returned.
+  auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const {
+    return getAbstractAttribute().replaceImmediateSubElements(*this, replAttrs,
+                                                              replTypes);
+  }
+
+  /// Walk this attribute and all attibutes/types nested within using the
+  /// provided walk functions. See `AttrTypeWalker` for information on the
+  /// supported walk function types.
+  template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
+  auto walk(WalkFns &&...walkFns) {
+    AttrTypeWalker walker;
+    (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
+    return walker.walk<Order>(*this);
+  }
+
+  /// Recursively replace all of the nested sub-attributes and sub-types using
+  /// the provided map functions. Returns nullptr in the case of failure. See
+  /// `AttrTypeReplacer` for information on the support replacement function
+  /// types.
+  template <typename... ReplacementFns>
+  auto replace(ReplacementFns &&...replacementFns) {
+    AttrTypeReplacer replacer;
+    (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
+     ...);
+    return replacer.replace(*this);
+  }
+
   /// Return the internal Attribute implementation.
   ImplType *getImpl() const { return impl; }
 
@@ -201,6 +243,22 @@ inline ::llvm::hash_code hash_value(const NamedAttribute &arg) {
   return DenseMapInfo<AttrPairT>::getHashValue(AttrPairT(arg.name, arg.value));
 }
 
+/// Allow walking and replacing the subelements of a NamedAttribute.
+template <>
+struct AttrTypeSubElementHandler<NamedAttribute> {
+  template <typename T>
+  static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
+    walker.walk(param.getName());
+    walker.walk(param.getValue());
+  }
+  template <typename T>
+  static T replace(T param, AttrSubElementReplacements &attrRepls,
+                   TypeSubElementReplacements &typeRepls) {
+    ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
+    return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AttributeTraitBase
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index a23da78a48d6..d8506cbcdad1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,7 +10,6 @@
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
-#include "mlir/IR/SubElementInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include <complex>

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0e7b5012b082..446bc375d3fa 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -18,7 +18,6 @@ include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
-include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the attributes defined in this file are prefixed with
 // `Builtin_`.  This is to differentiate the attributes here with the ones in
@@ -71,9 +70,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    SubElementAttrInterface
-  ]> {
+def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
   let summary = "A collection of other Attribute values";
   let description = [{
     Syntax:
@@ -491,9 +488,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    SubElementAttrInterface
-  ]> {
+def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
     Syntax:
@@ -1096,9 +1091,7 @@ def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> {
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
-    SubElementAttrInterface
-  ]> {
+def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
   let summary = "An Attribute containing a symbolic reference to an Operation";
   let description = [{
     Syntax:
@@ -1114,13 +1107,6 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
     may optionally contain a set of nested references that further resolve to a
     symbol nested within a different symbol table.
 
-    This attribute can only be held internally by
-    [array attributes](#array-attribute),
-    [dictionary attributes](#dictionary-attribute)(including the top-level
-    operation attribute dictionary) as well as attributes exposing it via
-    the `SubElementAttrInterface` interface. Symbol reference attributes
-    nested in types are currently not supported.
-
     **Rationale:** Identifying accesses to global data is critical to
     enabling efficient multi-threaded compilation. Restricting global
     data access to occur through symbols and limiting the places that can
@@ -1171,9 +1157,7 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
 // TypeAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_TypeAttr : Builtin_Attr<"Type", [
-    SubElementAttrInterface
-  ]> {
+def Builtin_TypeAttr : Builtin_Attr<"Type"> {
   let summary = "An Attribute containing a Type";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 0395e1329590..3c9d2c57f4d1 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -15,7 +15,6 @@
 
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
-include "mlir/IR/SubElementInterfaces.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name, list<Trait> traits = []>
@@ -28,9 +27,7 @@ class Builtin_LocationAttr<string name, list<Trait> traits = []>
 // CallSiteLoc
 //===----------------------------------------------------------------------===//
 
-def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [
-    SubElementAttrInterface
-  ]> {
+def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc"> {
   let summary = "A callsite source location";
   let description = [{
     Syntax:
@@ -107,9 +104,7 @@ def FileLineColLoc : Builtin_LocationAttr<"FileLineColLoc"> {
 // FusedLoc
 //===----------------------------------------------------------------------===//
 
-def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
-    SubElementAttrInterface
-  ]> {
+def FusedLoc : Builtin_LocationAttr<"FusedLoc"> {
   let summary = "A tuple of other source locations";
   let description = [{
     Syntax:
@@ -148,9 +143,7 @@ def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
 // NameLoc
 //===----------------------------------------------------------------------===//
 
-def NameLoc : Builtin_LocationAttr<"NameLoc", [
-    SubElementAttrInterface
-  ]> {
+def NameLoc : Builtin_LocationAttr<"NameLoc"> {
   let summary = "A named source location";
   let description = [{
     Syntax:
@@ -187,9 +180,7 @@ def NameLoc : Builtin_LocationAttr<"NameLoc", [
 // OpaqueLoc
 //===----------------------------------------------------------------------===//
 
-def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [
-    SubElementAttrInterface
-  ]> {
+def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
   let summary = "An opaque source location";
   let description = [{
     An instance of this location essentially contains a pointer to some data

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f06581a4b0d7..135fa9b559d8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -11,7 +11,6 @@
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/SubElementInterfaces.h"
 
 namespace llvm {
 class BitVector;

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 8b7bfafab568..5f9141d71b69 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,7 +17,6 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
 // This is to differentiate the types here with the ones in OpBase.td. We should
@@ -165,9 +164,7 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
-  ]> {
+def Builtin_Function : Builtin_Type<"Function"> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -314,7 +311,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+    ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
@@ -649,7 +646,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+    ShapedTypeInterface
   ], "TensorType"> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
@@ -753,9 +750,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
 // TupleType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Tuple : Builtin_Type<"Tuple", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
-  ]> {
+def Builtin_Tuple : Builtin_Type<"Tuple"> {
   let summary = "Fixed-sized collection of other types";
   let description = [{
     Syntax:
@@ -823,7 +818,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+    ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
@@ -895,7 +890,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+    ShapedTypeInterface
   ], "TensorType"> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
@@ -943,9 +938,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Vector : Builtin_Type<"Vector", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
-  ], "Type"> {
+def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 92f55b557baf..78d41d6dc4ab 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -43,13 +43,6 @@ mlir_tablegen(FunctionOpInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRFunctionInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIRFunctionInterfacesIncGen)
 
-set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td)
-mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls)
-mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs)
-mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls)
-mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRSubElementInterfacesIncGen)
-
 set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
 mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)

diff  --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index b772cf4b90e3..63b12899e249 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -15,7 +15,6 @@
 #define MLIR_IR_LOCATION_H
 
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/SubElementInterfaces.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
@@ -172,13 +171,13 @@ inline OpaqueLoc OpaqueLoc::get(T underlyingLocation, MLIRContext *context) {
 }
 
 //===----------------------------------------------------------------------===//
-// SubElementInterfaces
+// SubElements
 //===----------------------------------------------------------------------===//
 
 /// Enable locations to be introspected as sub-elements.
 template <>
 struct AttrTypeSubElementHandler<Location> {
-  static void walk(Location param, AttrTypeSubElementWalker &walker) {
+  static void walk(Location param, AttrTypeImmediateSubElementWalker &walker) {
     walker.walk(param);
   }
   static Location replace(Location param, AttrSubElementReplacements &attrRepls,

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 128ad815556c..061bd67e1090 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
 
+#include "mlir/IR/AttrTypeSubElements.h"
 #include "mlir/Support/InterfaceSupport.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StorageUniquer.h"
@@ -126,6 +127,51 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     };
   }
 
+  /// Walk all of the immediately nested sub-attributes and sub-types. This
+  /// method does not recurse into sub elements.
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const {
+    ::mlir::detail::walkImmediateSubElementsImpl(
+        *static_cast<const ConcreteT *>(this), walkAttrsFn, walkTypesFn);
+  }
+
+  /// Replace the immediately nested sub-attributes and sub-types with those
+  /// provided. The order of the provided elements is derived from the order of
+  /// the elements returned by the callbacks of `walkImmediateSubElements`. The
+  /// element at index 0 would replace the very first attribute given by
+  /// `walkImmediateSubElements`. On success, the new instance with the values
+  /// replaced is returned. If replacement fails, nullptr is returned.
+  ///
+  /// Note that replacing the sub-elements of mutable types or attributes is
+  /// not currently supported by the interface. If an implementing type or
+  /// attribute is mutable, it should return `nullptr` if it has no mechanism
+  /// for replacing sub elements.
+  auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const {
+    return ::mlir::detail::replaceImmediateSubElementsImpl(
+        *static_cast<const ConcreteT *>(this), replAttrs, replTypes);
+  }
+
+  /// Returns a function that walks immediate sub elements of a given instance
+  /// of the storage user.
+  static auto getWalkImmediateSubElementsFn() {
+    return [](auto instance, function_ref<void(Attribute)> walkAttrsFn,
+              function_ref<void(Type)> walkTypesFn) {
+      cast<ConcreteT>(instance).walkImmediateSubElements(walkAttrsFn,
+                                                         walkTypesFn);
+    };
+  }
+
+  /// Returns a function that replaces immediate sub elements of a given
+  /// instance of the storage user.
+  static auto getReplaceImmediateSubElementsFn() {
+    return [](auto instance, ArrayRef<Attribute> replAttrs,
+              ArrayRef<Type> replTypes) {
+      return cast<ConcreteT>(instance).replaceImmediateSubElements(replAttrs,
+                                                                   replTypes);
+    };
+  }
+
   /// Attach the given models as implementations of the corresponding interfaces
   /// for the concrete storage user class. The type must be registered with the
   /// context, i.e. the dialect to which the type belongs must be loaded. The

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
deleted file mode 100644
index d7054b422181..000000000000
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ /dev/null
@@ -1,142 +0,0 @@
-//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains a set of interfaces that can be used to interface with
-// sub-elements, e.g. held attributes and types, of a composite attribute or
-// type.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_
-#define MLIR_IR_SUBELEMENTINTERFACES_TD_
-
-include "mlir/IR/OpBase.td"
-
-//===----------------------------------------------------------------------===//
-// SubElementInterfaceBase
-//===----------------------------------------------------------------------===//
-
-class SubElementInterfaceBase<string interfaceName, string attrOrType,
-                              string derivedValue> {
-  string cppNamespace = "::mlir";
-
-  list<InterfaceMethod> methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Walk all of the immediately nested sub-attributes and sub-types. This
-        method does not recurse into sub elements.
-      }], "void", "walkImmediateSubElements",
-      (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
-           "llvm::function_ref<void(mlir::Type)>":$walkTypesFn),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{
-        ::mlir::detail::walkImmediateSubElementsImpl(
-          }] # derivedValue # [{, walkAttrsFn, walkTypesFn);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Replace the immediately nested sub-attributes and sub-types with those provided.
-        The order of the provided elements is derived from the order of the elements
-        returned by the callbacks of `walkImmediateSubElements`. The element at index 0
-        would replace the very first attribute given by `walkImmediateSubElements`.
-        On success, the new instance with the values replaced is returned. If replacement
-        fails, nullptr is returned.
-
-        Note that replacing the sub-elements of mutable types or attributes is
-        not currently supported by the interface. If an implementing type or
-        attribute is mutable, it should return `nullptr` if it has no mechanism
-        for replacing sub elements.
-      }], attrOrType, "replaceImmediateSubElements",
-      (ins "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
-           "::llvm::ArrayRef<::mlir::Type>":$replTypes),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{
-        return ::mlir::detail::replaceImmediateSubElementsImpl(
-           }] # derivedValue # [{, replAttrs, replTypes);
-      }]>,
-  ];
-
-  code extraClassDeclaration = [{
-    /// Walk all of the held sub-attributes and sub-types.
-    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
-                         llvm::function_ref<void(mlir::Type)> walkTypesFn);
-
-    /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. Returns nullptr in the case of failure. See
-    /// `AttrTypeReplacer` for information on the support replacement function types.
-    template <typename... ReplacementFns>
-    }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
-      AttrTypeReplacer replacer;
-      (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
-      return replacer.replace(*this);
-    }
-  }];
-  code extraTraitClassDeclaration = [{
-    /// Walk all of the held sub-attributes and sub-types.
-    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
-                         llvm::function_ref<void(mlir::Type)> walkTypesFn) {
-      }] # interfaceName # " interface(" # derivedValue # [{);
-      interface.walkSubElements(walkAttrsFn, walkTypesFn);
-    }
-
-    /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. Returns nullptr in the case of failure. See
-    /// `AttrTypeReplacer` for information on the support replacement function types.
-    template <typename... ReplacementFns>
-    }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
-      AttrTypeReplacer replacer;
-      (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
-      return replacer.replace(}] # derivedValue # [{);
-    }
-  }];
-  code extraSharedClassDeclaration = [{
-    /// Walk all of the held sub-attributes.
-    void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
-      walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
-    }
-    /// Walk all of the held sub-types.
-    void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
-      walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
-    }
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// SubElementAttrInterface
-//===----------------------------------------------------------------------===//
-
-def SubElementAttrInterface
-    : AttrInterface<"SubElementAttrInterface">,
-      SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute",
-                              "$_attr"> {
-  let description = [{
-    An interface used to query and manipulate sub-elements, such as sub-types
-    and sub-attributes of a composite attribute.
-
-    To support the introspection of custom parameters that hold sub-elements,
-    a specialization of the `AttrTypeSubElementHandler` class must be provided.
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// SubElementTypeInterface
-//===----------------------------------------------------------------------===//
-
-def SubElementTypeInterface
-    : TypeInterface<"SubElementTypeInterface">,
-      SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type",
-                              "$_type"> {
-  let description = [{
-    An interface used to query and manipulate sub-elements, such as sub-types
-    and sub-attributes of a composite type.
-
-    To support the introspection of custom parameters that hold sub-elements,
-    a specialization of the `AttrTypeSubElementHandler` class must be provided.
-  }];
-}
-
-#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_

diff  --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index b8215a80cb6f..37a7c242ee84 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -166,13 +166,13 @@ inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
 }
 
 //===----------------------------------------------------------------------===//
-// SubElementInterfaces
+// SubElements
 //===----------------------------------------------------------------------===//
 
 /// Enable TypeRange to be introspected for sub-elements.
 template <>
 struct AttrTypeSubElementHandler<TypeRange> {
-  static void walk(TypeRange param, AttrTypeSubElementWalker &walker) {
+  static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) {
     walker.walkRange(param);
   }
   static TypeRange replace(TypeRange param,

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 2bb28fd261eb..2aa6b1a59e86 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -30,6 +30,10 @@ class MLIRContext;
 class AbstractType {
 public:
   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+  using WalkImmediateSubElementsFn = function_ref<void(
+      Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
+  using ReplaceImmediateSubElementsFn =
+      function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>;
 
   /// Look up the specified abstract type in the MLIRContext and return a
   /// reference to it.
@@ -40,17 +44,23 @@ class AbstractType {
   template <typename T>
   static AbstractType get(Dialect &dialect) {
     return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
-                        T::getTypeID());
+                        T::getWalkImmediateSubElementsFn(),
+                        T::getReplaceImmediateSubElementsFn(), T::getTypeID());
   }
 
   /// This method is used by Dialect objects to register types with
   /// custom TypeIDs.
   /// The use of this method is in general discouraged in favor of
   /// 'get<CustomType>(dialect)';
-  static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-                          HasTraitFn &&hasTrait, TypeID typeID) {
+  static AbstractType
+  get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+      HasTraitFn &&hasTrait,
+      WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+      ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+      TypeID typeID) {
     return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
-                        typeID);
+                        walkImmediateSubElementsFn,
+                        replaceImmediateSubElementsFn, typeID);
   }
 
   /// Return the dialect this type was registered to.
@@ -78,14 +88,29 @@ class AbstractType {
   /// Returns true if the type has a particular trait.
   bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
 
+  /// Walk the immediate sub-elements of the given type.
+  void walkImmediateSubElements(Type type,
+                                function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
+
+  /// Replace the immediate sub-elements of the given type.
+  Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
+
   /// Return the unique identifier representing the concrete type class.
   TypeID getTypeID() const { return typeID; }
 
 private:
   AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-               HasTraitFn &&hasTrait, TypeID typeID)
+               HasTraitFn &&hasTrait,
+               WalkImmediateSubElementsFn walkImmediateSubElementsFn,
+               ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
+               TypeID typeID)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
-        hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
+        hasTraitFn(std::move(hasTrait)),
+        walkImmediateSubElementsFn(walkImmediateSubElementsFn),
+        replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
+        typeID(typeID) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -106,6 +131,12 @@ class AbstractType {
   /// Function to check if the type has a particular trait.
   HasTraitFn hasTraitFn;
 
+  /// Function to walk the immediate sub-elements of this type.
+  WalkImmediateSubElementsFn walkImmediateSubElementsFn;
+
+  /// Function to replace the immediate sub-elements of this type.
+  ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
+
   /// The unique identifier of the derived Type class.
   const TypeID typeID;
 };

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 9d64a77742ef..cc7ecfb9d468 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -186,11 +186,52 @@ class Type {
   }
 
   /// Return the abstract type descriptor for this type.
-  const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
+  const AbstractTy &getAbstractType() const { return impl->getAbstractType(); }
 
   /// Return the Type implementation.
   ImplType *getImpl() const { return impl; }
 
+  /// Walk all of the immediately nested sub-attributes and sub-types. This
+  /// method does not recurse into sub elements.
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const {
+    getAbstractType().walkImmediateSubElements(*this, walkAttrsFn, walkTypesFn);
+  }
+
+  /// Replace the immediately nested sub-attributes and sub-types with those
+  /// provided. The order of the provided elements is derived from the order of
+  /// the elements returned by the callbacks of `walkImmediateSubElements`. The
+  /// element at index 0 would replace the very first attribute given by
+  /// `walkImmediateSubElements`. On success, the new instance with the values
+  /// replaced is returned. If replacement fails, nullptr is returned.
+  auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const {
+    return getAbstractType().replaceImmediateSubElements(*this, replAttrs,
+                                                         replTypes);
+  }
+
+  /// Walk this type and all attributes/types nested within using the
+  /// provided walk functions. See `AttrTypeWalker` for information on the
+  /// supported walk function types.
+  template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
+  auto walk(WalkFns &&...walkFns) {
+    AttrTypeWalker walker;
+    (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
+    return walker.walk<Order>(*this);
+  }
+
+  /// Recursively replace all of the nested sub-attributes and sub-types using
+  /// the provided map functions. Returns nullptr in the case of failure. See
+  /// `AttrTypeReplacer` for information on the supported replacement function
+  /// types.
+  template <typename... ReplacementFns>
+  auto replace(ReplacementFns &&...replacementFns) {
+    AttrTypeReplacer replacer;
+    (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
+     ...);
+    return replacer.replace(*this);
+  }
+
 protected:
   ImplType *impl{nullptr};
 };

diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index b170c2655f70..51e79339b7bb 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -35,7 +35,7 @@ class WalkResult {
   enum ResultEnum { Interrupt, Advance, Skip } result;
 
 public:
-  WalkResult(ResultEnum result) : result(result) {}
+  WalkResult(ResultEnum result = Advance) : result(result) {}
 
   /// Allow LogicalResult to interrupt the walk on failure.
   WalkResult(LogicalResult result)

diff --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
index 9579111bc618..1d292c5fa45e 100644
--- a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
@@ -80,20 +80,15 @@ IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
 
 void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
-  typeConverter.addConversion([mapping](Type type) -> std::optional<Type> {
-    auto subElementType = type.dyn_cast_or_null<SubElementTypeInterface>();
-    if (!subElementType)
-      return type;
-    Type newType = subElementType.replaceSubElements(
-        [mapping](Attribute attr) -> std::optional<Attribute> {
-          auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
-          if (!memorySpaceAttr)
-            return std::nullopt;
-          auto newValue = wrapNumericMemorySpace(
-              attr.getContext(), mapping(memorySpaceAttr.getValue()));
-          return newValue;
-        });
-    return newType;
+  typeConverter.addConversion([mapping](Type type) {
+    return type.replace([mapping](Attribute attr) -> std::optional<Attribute> {
+      auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
+      if (!memorySpaceAttr)
+        return std::nullopt;
+      auto newValue = wrapNumericMemorySpace(
+          attr.getContext(), mapping(memorySpaceAttr.getValue()));
+      return newValue;
+    });
   });
 }
 

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0981b3b1cbe2..af5cb6f354a2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Verifier.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
@@ -841,13 +840,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
     }
 
     // For most builtin types, we can simply walk the sub elements.
-    if (auto subElementInterface = dyn_cast<SubElementTypeInterface>(type)) {
-      auto visitFn = [&](auto element) {
-        if (element)
-          (void)printAlias(element);
-      };
-      subElementInterface.walkImmediateSubElements(visitFn, visitFn);
-    }
+    auto visitFn = [&](auto element) {
+      if (element)
+        (void)printAlias(element);
+    };
+    type.walkImmediateSubElements(visitFn, visitFn);
   }
 
   /// Consider the given type to be printed for an alias.

diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
similarity index 53%
rename from mlir/lib/IR/SubElementInterfaces.cpp
rename to mlir/lib/IR/AttrTypeSubElements.cpp
index 528e0cadfa7c..79b04966be6e 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -1,4 +1,4 @@
-//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===//
+//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,96 +6,77 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Operation.h"
-
-#include "llvm/ADT/DenseSet.h"
 #include <optional>
 
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// SubElementInterface
+// AttrTypeWalker
 //===----------------------------------------------------------------------===//
 
-//===----------------------------------------------------------------------===//
-// WalkSubElements
-
-template <typename InterfaceT>
-static void walkSubElementsImpl(InterfaceT interface,
-                                function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn,
-                                DenseSet<Attribute> &visitedAttrs,
-                                DenseSet<Type> &visitedTypes) {
-  interface.walkImmediateSubElements(
-      [&](Attribute attr) {
-        // Guard against potentially null inputs. This removes the need for the
-        // derived attribute/type to do it.
-        if (!attr)
-          return;
-
-        // Avoid infinite recursion when visiting sub attributes later, if this
-        // is a mutable attribute.
-        if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
-          if (!visitedAttrs.insert(attr).second)
-            return;
-        }
-
-        // Walk any sub elements first.
-        if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
-                              visitedTypes);
+WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
+  return walkImpl(attr, attrWalkFns, order);
+}
+WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
+  return walkImpl(type, typeWalkFns, order);
+}
 
-        // Walk this attribute.
-        walkAttrsFn(attr);
-      },
-      [&](Type type) {
-        // Guard against potentially null inputs. This removes the need for the
-        // derived attribute/type to do it.
-        if (!type)
-          return;
-
-        // Avoid infinite recursion when visiting sub types later, if this
-        // is a mutable type.
-        if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
-          if (!visitedTypes.insert(type).second)
-            return;
-        }
+template <typename T, typename WalkFns>
+WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
+                                    WalkOrder order) {
+  // Check if we've already walked this element before.
+  auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
+  auto it = visitedAttrTypes.find(key);
+  if (it != visitedAttrTypes.end())
+    return it->second;
+  visitedAttrTypes.try_emplace(key, WalkResult::advance());
 
-        // Walk any sub elements first.
-        if (auto interface = type.dyn_cast<SubElementTypeInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
-                              visitedTypes);
+  // If we are walking in post order, walk the sub elements first.
+  if (order == WalkOrder::PostOrder) {
+    if (walkSubElements(element, order).wasInterrupted())
+      return visitedAttrTypes[key] = WalkResult::interrupt();
+  }
 
-        // Walk this type.
-        walkTypesFn(type);
-      });
-}
+  // Walk this element, bailing if skipped or interrupted.
+  for (auto &walkFn : llvm::reverse(walkFns)) {
+    WalkResult walkResult = walkFn(element);
+    if (walkResult.wasInterrupted())
+      return visitedAttrTypes[key] = WalkResult::interrupt();
+    if (walkResult.wasSkipped())
+      return WalkResult::advance();
+  }
 
-void SubElementAttrInterface::walkSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) {
-  assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  DenseSet<Attribute> visitedAttrs;
-  DenseSet<Type> visitedTypes;
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
-                      visitedTypes);
+  // If we are walking in pre-order, walk the sub elements last.
+  if (order == WalkOrder::PreOrder) {
+    if (walkSubElements(element, order).wasInterrupted())
+      return WalkResult::interrupt();
+  }
+  return WalkResult::advance();
 }
 
-void SubElementTypeInterface::walkSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) {
-  assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  DenseSet<Attribute> visitedAttrs;
-  DenseSet<Type> visitedTypes;
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
-                      visitedTypes);
+template <typename T>
+WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
+  WalkResult result = WalkResult::advance();
+  auto walkFn = [&](auto element) {
+    if (element && !result.wasInterrupted())
+      result = walkImpl(element, order);
+  };
+  interface.walkImmediateSubElements(walkFn, walkFn);
+  return result.wasInterrupted() ? result : WalkResult::advance();
 }
 
 //===----------------------------------------------------------------------===//
 /// AttrTypeReplacer
 //===----------------------------------------------------------------------===//
 
+void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+  attrReplacementFns.emplace_back(std::move(fn));
+}
+void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+  typeReplacementFns.push_back(std::move(fn));
+}
+
 void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
                                          bool replaceLocs, bool replaceTypes) {
   // Functor that replaces the given element if the new value is different,
@@ -157,7 +138,6 @@ void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
 
 template <typename T>
 static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
-                                 DenseMap<T, T> &elementMap,
                                  SmallVectorImpl<T> &newElements,
                                  FailureOr<bool> &changed) {
   // Bail early if we failed at any point.
@@ -180,19 +160,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
   }
 }
 
-template <typename InterfaceT, typename T>
-T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
-                                       DenseMap<T, T> &interfaceMap) {
+template <typename T>
+T AttrTypeReplacer::replaceSubElements(T interface) {
   // Walk the current sub-elements, replacing them as necessary.
   SmallVector<Attribute, 16> newAttrs;
   SmallVector<Type, 16> newTypes;
   FailureOr<bool> changed = false;
   interface.walkImmediateSubElements(
       [&](Attribute element) {
-        updateSubElementImpl(element, *this, attrMap, newAttrs, changed);
+        updateSubElementImpl(element, *this, newAttrs, changed);
       },
       [&](Type element) {
-        updateSubElementImpl(element, *this, typeMap, newTypes, changed);
+        updateSubElementImpl(element, *this, newTypes, changed);
       });
   if (failed(changed))
     return nullptr;
@@ -205,12 +184,12 @@ T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
 }
 
 /// Shared implementation of replacing a given attribute or type element.
-template <typename InterfaceT, typename ReplaceFns, typename T>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
-                                DenseMap<T, T> &map) {
-  auto [it, inserted] = map.try_emplace(element, element);
+template <typename T, typename ReplaceFns>
+T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
+  const void *opaqueElement = element.getAsOpaquePointer();
+  auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
   if (!inserted)
-    return it->second;
+    return T::getFromOpaquePointer(it->second);
 
   T result = element;
   WalkResult walkResult = WalkResult::advance();
@@ -222,34 +201,42 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
   }
 
   // If an error occurred, return nullptr to indicate failure.
-  if (walkResult.wasInterrupted() || !result)
-    return map[element] = nullptr;
+  if (walkResult.wasInterrupted() || !result) {
+    attrTypeMap[opaqueElement] = nullptr;
+    return nullptr;
+  }
 
   // Handle replacing sub-elements if this element is also a container.
   if (!walkResult.wasSkipped()) {
-    if (auto interface = dyn_cast<InterfaceT>(result)) {
-      // Replace the sub elements of this element, bailing if we fail.
-      if (!(result = replaceSubElements(interface, map)))
-        return map[element] = nullptr;
+    // Replace the sub elements of this element, bailing if we fail.
+    if (!(result = replaceSubElements(result))) {
+      attrTypeMap[opaqueElement] = nullptr;
+      return nullptr;
     }
   }
 
-  return map[element] = result;
+  attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+  return result;
 }
 
 Attribute AttrTypeReplacer::replace(Attribute attr) {
-  return replaceImpl<SubElementAttrInterface>(attr, attrReplacementFns,
-                                              attrMap);
+  return replaceImpl(attr, attrReplacementFns);
 }
 
 Type AttrTypeReplacer::replace(Type type) {
-  return replaceImpl<SubElementTypeInterface>(type, typeReplacementFns,
-                                              typeMap);
+  return replaceImpl(type, typeReplacementFns);
 }
 
 //===----------------------------------------------------------------------===//
-// SubElementInterface Tablegen definitions
+// AttrTypeImmediateSubElementWalker
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/SubElementAttrInterfaces.cpp.inc"
-#include "mlir/IR/SubElementTypeInterfaces.cpp.inc"
+void AttrTypeImmediateSubElementWalker::walk(Attribute element) {
+  if (element)
+    walkAttrsFn(element);
+}
+
+void AttrTypeImmediateSubElementWalker::walk(Type element) {
+  if (element)
+    walkTypesFn(element);
+}

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 4e585da7d909..2798944c2df3 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -12,6 +12,23 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// AbstractAttribute
+//===----------------------------------------------------------------------===//
+
+void AbstractAttribute::walkImmediateSubElements(
+    Attribute attr, function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkImmediateSubElementsFn(attr, walkAttrsFn, walkTypesFn);
+}
+
+Attribute
+AbstractAttribute::replaceImmediateSubElements(Attribute attr,
+                                               ArrayRef<Attribute> replAttrs,
+                                               ArrayRef<Type> replTypes) const {
+  return replaceImmediateSubElementsFn(attr, replAttrs, replTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Attribute
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 355ddd4d450a..8b4fb42e03ea 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRIR
   AffineMap.cpp
   AsmPrinter.cpp
   Attributes.cpp
+  AttrTypeSubElements.cpp
   Block.cpp
   Builders.cpp
   BuiltinAttributeInterfaces.cpp
@@ -26,7 +27,6 @@ add_mlir_library(MLIRIR
   PatternMatch.cpp
   Region.cpp
   RegionKindInterface.cpp
-  SubElementInterfaces.cpp
   SymbolTable.cpp
   TensorEncoding.cpp
   Types.cpp
@@ -54,7 +54,6 @@ add_mlir_library(MLIRIR
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen
-  MLIRSubElementInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
 

diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 701683fdcb5c..00849857b51d 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -407,9 +407,10 @@ void ExtensibleDialect::registerDynamicType(
   assert(registered &&
          "Trying to create a new dynamic type with an existing name");
 
-  auto abstractType =
-      AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
-                        DynamicType::getHasTraitFn(), typeID);
+  auto abstractType = AbstractType::get(
+      *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
+      DynamicType::getWalkImmediateSubElementsFn(),
+      DynamicType::getReplaceImmediateSubElementsFn(), typeID);
 
   /// Add the type to the dialect and the type uniquer.
   addType(typeID, std::move(abstractType));
@@ -436,9 +437,10 @@ void ExtensibleDialect::registerDynamicAttr(
   assert(registered &&
          "Trying to create a new dynamic attribute with an existing name");
 
-  auto abstractAttr =
-      AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
-                             DynamicAttr::getHasTraitFn(), typeID);
+  auto abstractAttr = AbstractAttribute::get(
+      *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
+      DynamicAttr::getWalkImmediateSubElementsFn(),
+      DynamicAttr::getReplaceImmediateSubElementsFn(), typeID);
 
   /// Add the type to the dialect and the type uniquer.
   addAttribute(typeID, std::move(abstractAttr));

diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 4c3b3bb8be5a..446cce3db413 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -485,66 +485,14 @@ LogicalResult detail::verifySymbol(Operation *op) {
 static WalkResult
 walkSymbolRefs(Operation *op,
                function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
-  // Check to see if the operation has any attributes.
-  DictionaryAttr attrDict = op->getAttrDictionary();
-  if (attrDict.empty())
-    return WalkResult::advance();
-
-  // A worklist of a container attribute and the current index into the held
-  // attribute list.
-  struct WorklistItem {
-    SubElementAttrInterface container;
-    SmallVector<Attribute> immediateSubElements;
-
-    explicit WorklistItem(SubElementAttrInterface container) {
-      SmallVector<Attribute> subElements;
-      container.walkImmediateSubElements(
-          [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
-      immediateSubElements = std::move(subElements);
-    }
-  };
-
-  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
-  SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
-
-  // Process the symbol references within the given nested attribute range.
-  auto processAttrs = [&](int &index,
-                          WorklistItem &worklistItem) -> WalkResult {
-    for (Attribute attr :
-         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
-      // Invoke the provided callback if we find a symbol use and check for a
-      // requested interrupt.
-      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
+  return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
+      [&](SymbolRefAttr symbolRef) {
         if (callback({op, symbolRef}).wasInterrupted())
           return WalkResult::interrupt();
 
-        /// Check for a nested container attribute, these will also need to be
-        /// walked.
-      } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
-        attrWorklist.emplace_back(interface);
-        curAccessChain.push_back(-1);
-        return WalkResult::advance();
-      }
-      // Make sure to keep the index counter in sync.
-      ++index;
-    }
-
-    // Pop this container attribute from the worklist.
-    attrWorklist.pop_back();
-    curAccessChain.pop_back();
-    return WalkResult::advance();
-  };
-
-  WalkResult result = WalkResult::advance();
-  do {
-    WorklistItem &item = attrWorklist.back();
-    int &index = curAccessChain.back();
-    ++index;
-
-    // Process the given attribute, which is guaranteed to be a container.
-    result = processAttrs(index, item);
-  } while (!attrWorklist.empty() && !result.wasInterrupted());
-  return result;
+        // Don't walk nested references.
+        return WalkResult::skip();
+      });
 }
 
 /// Walk all of the uses, for any symbol, that are nested within the given

diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 9dc8e6380c79..1d65fccb82b8 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -31,7 +31,7 @@ struct IntegerTypeStorage : public TypeStorage {
       : width(width), signedness(signedness) {}
 
   /// The hash key used for uniquing.
-  using KeyTy = std::pair<unsigned, IntegerType::SignednessSemantics>;
+  using KeyTy = std::tuple<unsigned, IntegerType::SignednessSemantics>;
 
   static llvm::hash_code hashKey(const KeyTy &key) {
     return llvm::hash_value(key);
@@ -44,7 +44,7 @@ struct IntegerTypeStorage : public TypeStorage {
   static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
                                        KeyTy key) {
     return new (allocator.allocate<IntegerTypeStorage>())
-        IntegerTypeStorage(key.first, key.second);
+        IntegerTypeStorage(std::get<0>(key), std::get<1>(key));
   }
 
   KeyTy getAsKey() const { return KeyTy(width, signedness); }

diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 670974bbf837..070ed4b14686 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -12,6 +12,22 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// AbstractType
+//===----------------------------------------------------------------------===//
+
+void AbstractType::walkImmediateSubElements(
+    Type type, function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkImmediateSubElementsFn(type, walkAttrsFn, walkTypesFn);
+}
+
+Type AbstractType::replaceImmediateSubElements(Type type,
+                                               ArrayRef<Attribute> replAttrs,
+                                               ArrayRef<Type> replTypes) const {
+  return replaceImmediateSubElementsFn(type, replAttrs, replTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Type
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index c7d48b6c4eb1..fe1c80e8f400 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -76,7 +76,7 @@ module {
 
 // -----
 
-// Check that replacement works in any implementations of SubElementsAttrInterface
+// Check that replacement works in any implementations of SubElements.
 module {
     // CHECK: func private @replaced_foo
     func.func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }

diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index f28afe4daf05..bd7c050f9e85 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -21,7 +21,6 @@ include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpAsmInterface.td"
-include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -120,9 +119,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
   let hasCustomAssemblyFormat = 1;
 }
 
-def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
-    SubElementAttrInterface
-  ]> {
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess"> {
   let mnemonic = "sub_elements_access";
 
   let parameters = (ins

diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index b6a7e275781a..c6c8f58ea552 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -22,7 +22,6 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
@@ -132,7 +131,6 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
 class TestRecursiveType
     : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
                                     TestRecursiveTypeStorage,
-                                    ::mlir::SubElementTypeInterface::Trait,
                                     ::mlir::TypeTrait::IsMutable> {
 public:
   using Base::Base;

diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index 295c58f925e8..f39093498e09 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -8,7 +8,6 @@
 
 #include "LLVMTestBase.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/SubElementInterfaces.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -31,33 +30,24 @@ TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
   Type barBody[] = {LLVMPointerType::get(fooStructTy)};
   ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*isPacked=*/false)));
 
-  auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
-  ASSERT_TRUE(bool(subElementInterface));
   // Test if walkSubElements goes into infinite loops.
   SmallVector<Type, 4> subElementTypes;
-  subElementInterface.walkSubElements(
-      [](Attribute attr) {},
-      [&](Type type) { subElementTypes.push_back(type); });
-  // We don't record LLVMPointerType (because it's immutable), thus
-  // !llvm.ptr<struct<"bar",...>> will be visited twice.
-  ASSERT_EQ(subElementTypes.size(), 5U);
+  fooStructTy.walk([&](Type type) { subElementTypes.push_back(type); });
+  ASSERT_EQ(subElementTypes.size(), 4U);
 
-  // !llvm.ptr<struct<"bar",...>>
+  // !llvm.ptr<struct<"foo",...>>
   ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
 
-  // !llvm.struct<"foo",...>
+  // !llvm.struct<"bar",...>
   auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
   ASSERT_TRUE(bool(structType));
-  ASSERT_TRUE(structType.getName().equals("foo"));
+  ASSERT_TRUE(structType.getName().equals("bar"));
 
-  // !llvm.ptr<struct<"foo",...>>
+  // !llvm.ptr<struct<"bar",...>>
   ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
 
-  // !llvm.struct<"bar",...>
+  // !llvm.struct<"foo",...>
   structType = subElementTypes[3].dyn_cast<LLVMStructType>();
   ASSERT_TRUE(bool(structType));
-  ASSERT_TRUE(structType.getName().equals("bar"));
-
-  // !llvm.ptr<struct<"bar",...>>
-  ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
+  ASSERT_TRUE(structType.getName().equals("foo"));
 }

diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 13c7762563b1..7c0572ee89fe 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -422,4 +422,25 @@ TEST(SparseElementsAttrTest, GetZero) {
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
 
+//===----------------------------------------------------------------------===//
+// SubElements
+//===----------------------------------------------------------------------===//
+
+TEST(SubElementTest, Nested) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  BoolAttr trueAttr = builder.getBoolAttr(true);
+  BoolAttr falseAttr = builder.getBoolAttr(false);
+  ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
+  StringAttr strAttr = builder.getStringAttr("array");
+  DictionaryAttr dictAttr =
+      builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
+
+  SmallVector<Attribute> subAttrs;
+  dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); });
+  EXPECT_EQ(llvm::ArrayRef(subAttrs),
+            ArrayRef<Attribute>(
+                {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
+}
 } // namespace

diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 9dcfb71744be..7d49283c59c3 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_unittest(MLIRIRTests
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
-  SubElementInterfaceTest.cpp
   TypeTest.cpp
 
   DEPENDS

diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index e77e8794d696..9b20d3c1219e 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -41,24 +41,6 @@ TEST(InterfaceTest, OpInterfaceDenseMapKey) {
   EXPECT_FALSE(opSet.contains(op3));
 }
 
-TEST(InterfaceTest, AttrInterfaceDenseMapKey) {
-  MLIRContext context;
-  context.loadDialect<test::TestDialect>();
-
-  OpBuilder builder(&context);
-
-  DenseSet<SubElementAttrInterface> attrSet;
-  auto attr1 = builder.getArrayAttr({});
-  auto attr2 = builder.getI32ArrayAttr({0});
-  auto attr3 = builder.getI32ArrayAttr({1});
-  attrSet.insert(attr1);
-  attrSet.insert(attr2);
-  attrSet.erase(attr1);
-  EXPECT_FALSE(attrSet.contains(attr1));
-  EXPECT_TRUE(attrSet.contains(attr2));
-  EXPECT_FALSE(attrSet.contains(attr3));
-}
-
 TEST(InterfaceTest, TypeInterfaceDenseMapKey) {
   MLIRContext context;
   context.loadDialect<test::TestDialect>();

diff --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp
deleted file mode 100644
index ab461f4dc340..000000000000
--- a/mlir/unittests/IR/SubElementInterfaceTest.cpp
+++ /dev/null
@@ -1,36 +0,0 @@
-//===- SubElementInterfaceTest.cpp - SubElementInterface unit tests -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/SubElementInterfaces.h"
-#include "gtest/gtest.h"
-#include <cstdint>
-
-using namespace mlir;
-using namespace mlir::detail;
-
-namespace {
-TEST(SubElementInterfaceTest, Nested) {
-  MLIRContext context;
-  Builder builder(&context);
-
-  BoolAttr trueAttr = builder.getBoolAttr(true);
-  BoolAttr falseAttr = builder.getBoolAttr(false);
-  ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
-  StringAttr strAttr = builder.getStringAttr("array");
-  DictionaryAttr dictAttr =
-      builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
-
-  SmallVector<Attribute> subAttrs;
-  dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); });
-  EXPECT_EQ(llvm::ArrayRef(subAttrs),
-            ArrayRef<Attribute>({strAttr, trueAttr, falseAttr, boolArrayAttr}));
-}
-
-} // namespace


        


More information about the Mlir-commits mailing list