[flang-commits] [flang] 38c219b - [mlir] Infer SubElementInterface implementations using the storage KeyTy

River Riddle via flang-commits flang-commits at lists.llvm.org
Fri Nov 4 18:26:50 PDT 2022


Author: River Riddle
Date: 2022-11-04T18:15:03-07:00
New Revision: 38c219b4a8ebe30d781a1ebbb9a9d29b24c28b39

URL: https://github.com/llvm/llvm-project/commit/38c219b4a8ebe30d781a1ebbb9a9d29b24c28b39
DIFF: https://github.com/llvm/llvm-project/commit/38c219b4a8ebe30d781a1ebbb9a9d29b24c28b39.diff

LOG: [mlir] Infer SubElementInterface implementations using the storage KeyTy

The KeyTy of attribute/type storage classes provide enough information for
automatically implementing the necessary sub element interface methods. This
removes the need for derived classes to do it themselves, which is both much
nicer and makes certain invariants (e.g. null handling) easier to handle. In cases where
explicit handling of a parameter type is necessary, a specialization of
`AttrTypeSubElementHandler` can be provided for that type to opt in to support.

This perturbs the numbering of a few aliases, which annoyingly breaks tests that
hard-code specific affine map alias numbers (e.g. `#map1`).

Differential Revision: https://reviews.llvm.org/D137374

Added: 
    

Modified: 
    flang/test/Fir/affine-promotion.fir
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinLocationAttributes.td
    mlir/include/mlir/IR/Location.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/SubElementInterfaces.h
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/include/mlir/IR/TypeRange.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/Location.cpp
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/lib/IR/TypeDetail.h
    mlir/test/Dialect/Affine/loop-tiling.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
    mlir/test/Dialect/SCF/for-loop-specialization.mlir
    mlir/test/Dialect/SCF/parallel-loop-specialization.mlir
    mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/IR/affine-map.mlir
    mlir/test/IR/memory-ops.mlir
    mlir/test/Transforms/loop-fusion-2.mlir
    mlir/test/Transforms/normalize-memrefs-ops.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/unittests/IR/SubElementInterfaceTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir
index 4879e51a44512..aae35c6ef5659 100644
--- a/flang/test/Fir/affine-promotion.fir
+++ b/flang/test/Fir/affine-promotion.fir
@@ -50,21 +50,21 @@ func.func @loop_with_load_and_store(%a1: !arr_d1, %a2: !arr_d1, %a3: !arr_d1) {
 // CHECK:    %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:    %[[VAL_4:.*]] = arith.constant 100 : index
 // CHECK:    %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-// CHECK:    %[[VAL_6:.*]] = affine.apply #map(){{\[}}%[[VAL_3]], %[[VAL_4]]]
+// CHECK:    %[[VAL_6:.*]] = affine.apply #{{.*}}(){{\[}}%[[VAL_3]], %[[VAL_4]]]
 // CHECK:    %[[VAL_7:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_6]]
 // CHECK:    %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
 // CHECK:    %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
 // CHECK:    %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
-// CHECK:    affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] {
-// CHECK:      %[[VAL_12:.*]] = affine.apply #map2(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:    affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_4]]] {
+// CHECK:      %[[VAL_12:.*]] = affine.apply #{{.*}}(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
 // CHECK:      %[[VAL_13:.*]] = affine.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:      %[[VAL_14:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:      %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
 // CHECK:      affine.store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:    }
 // CHECK:    %[[VAL_16:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
-// CHECK:    affine.for %[[VAL_17:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] {
-// CHECK:      %[[VAL_18:.*]] = affine.apply #map2(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:    affine.for %[[VAL_17:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_4]]] {
+// CHECK:      %[[VAL_18:.*]] = affine.apply #{{.*}}(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
 // CHECK:      %[[VAL_19:.*]] = affine.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xf32>
 // CHECK:      %[[VAL_20:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf32>
 // CHECK:      %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
@@ -114,18 +114,18 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) {
 // CHECK:   %[[VAL_5:.*]] = arith.constant 100 : index
 // CHECK:   %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
 // CHECK:   %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
-// CHECK:   affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] {
-// CHECK:     %[[VAL_9:.*]] = affine.apply #map1(%[[VAL_8]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:   affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] {
+// CHECK:     %[[VAL_9:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
 // CHECK:     affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_9]]] : memref<?xf32>
 // CHECK:   }
-// CHECK:   affine.for %[[VAL_10:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] {
-// CHECK:     %[[VAL_11:.*]] = affine.apply #map1(%[[VAL_10]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:   affine.for %[[VAL_10:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] {
+// CHECK:     %[[VAL_11:.*]] = affine.apply #{{.*}}(%[[VAL_10]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
 // CHECK:     affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf32>
 // CHECK:   }
-// CHECK:   affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] {
+// CHECK:   affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] {
 // CHECK:     %[[VAL_13:.*]] = arith.subi %[[VAL_12]], %[[VAL_4]] : index
 // CHECK:     affine.if #set(%[[VAL_12]]) {
-// CHECK:       %[[VAL_14:.*]] = affine.apply #map1(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:       %[[VAL_14:.*]] = affine.apply #{{.*}}(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
 // CHECK:       affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
 // CHECK:     }
 // CHECK:   }

diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 7e54c2ee0cd1b..d19b1bf443ad7 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -959,6 +959,8 @@ User defined storage classes must adhere to the following:
 - Provide a method to hash an instance of the `KeyTy`. (Note: This is not
   necessary if an `llvm::DenseMapInfo<KeyTy>` specialization exists)
   - `static llvm::hash_code hashKey(const KeyTy &)`
+- Provide a method to generate the `KeyTy` from an instance of the storage class.
+  - `KeyTy getAsKey() const`
 
 Let's look at an example:
 
@@ -997,6 +999,11 @@ struct ComplexTypeStorage : public TypeStorage {
         ComplexTypeStorage(key.first, key.second);
   }
 
+  /// Construct an instance of the key from this storage class.
+  KeyTy getAsKey() const {
+    return KeyTy(nonZeroParam, integerType);
+  }
+
   /// The parametric data held by the storage class.
   unsigned nonZeroParam;
   Type integerType;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index b6b1f4c618ffd..0c689d0019f45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -151,7 +151,7 @@ def LLVM_DIBasicTypeAttr : LLVM_Attr<"DIBasicType", "di_basic_type",
 //===----------------------------------------------------------------------===//
 
 def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DIScopeAttr"> {
   let parameters = (ins
     LLVM_DILanguageParameter:$sourceLanguage,
@@ -168,7 +168,7 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [
 //===----------------------------------------------------------------------===//
 
 def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DITagParameter:$tag,
@@ -188,7 +188,7 @@ def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
 //===----------------------------------------------------------------------===//
 
 def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DITagParameter:$tag,
@@ -220,7 +220,7 @@ def LLVM_DIFileAttr : LLVM_Attr<"DIFile", "di_file", /*traits=*/[], "DIScopeAttr
 //===----------------------------------------------------------------------===//
 
 def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DIScopeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
@@ -244,7 +244,7 @@ def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [
 //===----------------------------------------------------------------------===//
 
 def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DIScopeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
@@ -266,7 +266,7 @@ def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_
 //===----------------------------------------------------------------------===//
 
 def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DINodeAttr"> {
   let parameters = (ins
     "DIScopeAttr":$scope,
@@ -296,7 +296,7 @@ def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
 //===----------------------------------------------------------------------===//
 
 def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DIScopeAttr"> {
   let parameters = (ins
     "DICompileUnitAttr":$compileUnit,
@@ -346,7 +346,7 @@ def LLVM_DISubrangeAttr : LLVM_Attr<"DISubrange", "di_subrange", /*traits=*/[],
 //===----------------------------------------------------------------------===//
 
 def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ], "DITypeAttr"> {
   let parameters = (ins
     LLVM_DICallingConventionParameter:$callingConvention,

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 06eb6cb5f0424..70f47323fc85f 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -72,7 +72,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -510,7 +510,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
@@ -1115,7 +1115,7 @@ def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "An Attribute containing a symbolic reference to an Operation";
   let description = [{
@@ -1190,7 +1190,7 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_TypeAttr : Builtin_Attr<"Type", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "An Attribute containing a Type";
   let description = [{

diff  --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index ca96fb9e53bbc..0395e13295904 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -29,7 +29,7 @@ class Builtin_LocationAttr<string name, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "A callsite source location";
   let description = [{
@@ -108,7 +108,7 @@ def FileLineColLoc : Builtin_LocationAttr<"FileLineColLoc"> {
 //===----------------------------------------------------------------------===//
 
 def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "A tuple of other source locations";
   let description = [{
@@ -149,7 +149,7 @@ def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
 //===----------------------------------------------------------------------===//
 
 def NameLoc : Builtin_LocationAttr<"NameLoc", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "A named source location";
   let description = [{
@@ -188,7 +188,7 @@ def NameLoc : Builtin_LocationAttr<"NameLoc", [
 //===----------------------------------------------------------------------===//
 
 def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    SubElementAttrInterface
   ]> {
   let summary = "An opaque source location";
   let description = [{

diff  --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 03f6e4e55896e..b772cf4b90e39 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -107,6 +107,9 @@ class Location {
     return LocationAttr(reinterpret_cast<const AttributeStorage *>(pointer));
   }
 
+  /// Support llvm style casting from Attribute: true iff `attr` is a LocationAttr.
+  static bool classof(Attribute attr) { return llvm::isa<LocationAttr>(attr); }
+
 protected:
   /// The internal backing location attribute.
   LocationAttr impl;
@@ -167,6 +170,23 @@ inline OpaqueLoc OpaqueLoc::get(T underlyingLocation, MLIRContext *context) {
   return get(reinterpret_cast<uintptr_t>(underlyingLocation), TypeID::get<T>(),
              UnknownLoc::get(context));
 }
+
+//===----------------------------------------------------------------------===//
+// SubElementInterfaces
+//===----------------------------------------------------------------------===//
+
+/// Enable locations to be introspected as a single LocationAttr sub-element.
+template <>
+struct AttrTypeSubElementHandler<Location> {
+  static void walk(Location param, AttrTypeSubElementWalker &walker) {
+    walker.walk(param);
+  }
+  static Location replace(Location param, AttrSubElementReplacements &attrRepls,
+                          TypeSubElementReplacements &typeRepls) {
+    return cast<LocationAttr>(attrRepls.take_front(1)[0]); // Consumes one.
+  }
+};
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 074764caf33b1..ff5a0630e4fff 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -180,6 +180,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     return ConcreteT((const typename BaseT::ImplType *)ptr);
   }
 
+  /// Utility for easy access to the storage instance, cast to ImplType.
+  ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
+
 protected:
   /// Mutate the current storage instance. This will not change the unique key.
   /// The arguments are forwarded to 'ConcreteT::mutate'.
@@ -199,9 +202,6 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     return success();
   }
 
-  /// Utility for easy access to the storage instance.
-  ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
-
 private:
   /// Trait to check if T provides a 'ConcreteEntity' type alias.
   template <typename T>

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index 2c40e4edfa0fa..0f3045de5f86a 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -23,6 +23,253 @@ template <typename T>
 using SubElementReplFn = function_ref<T(T)>;
 template <typename T>
 using SubElementResultReplFn = function_ref<std::pair<T, WalkResult>(T)>;
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeSubElementHandler
+//===----------------------------------------------------------------------===//
+
+/// This class is used by AttrTypeSubElementHandler instances to walk sub
+/// attributes and types, forwarding each non-null element to a callback.
+class AttrTypeSubElementWalker {
+public:
+  AttrTypeSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
+                           function_ref<void(Type)> walkTypesFn)
+      : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
+
+  /// Walk an attribute; null attributes are skipped.
+  void walk(Attribute element) {
+    if (element)
+      walkAttrsFn(element);
+  }
+  /// Walk a type; null types are skipped.
+  void walk(Type element) {
+    if (element)
+      walkTypesFn(element);
+  }
+  /// Walk a range of attributes or types, skipping null elements.
+  template <typename RangeT>
+  void walkRange(RangeT &&elements) {
+    for (auto element : elements)
+      walk(element);
+  }
+
+private:
+  function_ref<void(Attribute)> walkAttrsFn; // Non-owning callback reference.
+  function_ref<void(Type)> walkTypesFn;      // Non-owning callback reference.
+};
+
+/// This class is used by AttrTypeSubElementHandler instances to consume sub
+/// element replacements sequentially, front to back.
+template <typename T>
+class AttrTypeSubElementReplacements {
+public:
+  AttrTypeSubElementReplacements(ArrayRef<T> repls) : repls(repls) {}
+
+  /// Take the first N replacements as an ArrayRef, dropping them from
+  /// this replacement list. The result views the constructor-provided storage.
+  ArrayRef<T> take_front(unsigned n) {
+    ArrayRef<T> elements = repls.take_front(n);
+    repls = repls.drop_front(n);
+    return elements;
+  }
+
+private:
+  /// The replacements that have not been taken yet (a non-owning view).
+  ArrayRef<T> repls;
+};
+using AttrSubElementReplacements = AttrTypeSubElementReplacements<Attribute>;
+using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>;
+
+/// This class provides support for interacting with the
+/// SubElementInterfaces for different types of parameters. An
+/// implementation of this class should be provided for any parameter class
+/// that may contain an attribute or type. There are two main methods of
+/// this class that need to be implemented:
+///
+///  - walk
+///
+///   This method should traverse into any sub elements of the parameter
+///   using the provided walker, or by invoking handlers for sub-types.
+///
+///  - replace
+///
+///   This method should extract any necessary sub elements using the
+///   provided replacer (in the same order walk visits them), or by invoking
+///   handlers for sub-types, and return the post-replacement parameter value.
+///
+template <typename T, typename Enable = void>
+struct AttrTypeSubElementHandler {
+  /// Default walk implementation that does nothing.
+  static inline void walk(const T &param, AttrTypeSubElementWalker &walker) {}
+
+  /// Default replace implementation forwards the parameter, consuming nothing.
+  template <typename ParamT>
+  static inline decltype(auto) replace(ParamT &&param,
+                                       AttrSubElementReplacements &attrRepls,
+                                       TypeSubElementReplacements &typeRepls) {
+    return std::forward<ParamT>(param);
+  }
+
+  /// Tag marking this fallback handler, which does not support sub-elements.
+  using DefaultHandlerTag = void;
+};
+
+/// True if any of Ts has a handler other than the DefaultHandlerTag fallback.
+namespace detail {
+template <typename T> // T is a handler class; probe it for DefaultHandlerTag.
+using has_default_sub_element_handler_t = typename T::DefaultHandlerTag;
+} // namespace detail
+template <typename... Ts>
+inline constexpr bool has_sub_attr_or_type_v =
+    (!llvm::is_detected<detail::has_default_sub_element_handler_t,
+                        AttrTypeSubElementHandler<Ts>>::value || ...);
+
+/// Implementation for derived Attributes and Types (including interfaces).
+template <typename T>
+struct AttrTypeSubElementHandler<
+    T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
+                        std::is_base_of_v<Type, T>>> {
+  static void walk(T param, AttrTypeSubElementWalker &walker) {
+    walker.walk(param);
+  }
+  static T replace(T param, AttrSubElementReplacements &attrRepls,
+                   TypeSubElementReplacements &typeRepls) {
+    if (!param)
+      return T();
+    if constexpr (std::is_base_of_v<Attribute, T>) {
+      return cast<T>(attrRepls.take_front(1)[0]);
+    } else {
+      // walk() visits type interfaces too, so they must consume one as well.
+      return cast<T>(typeRepls.take_front(1)[0]);
+    }
+  }
+};
+template <>
+struct AttrTypeSubElementHandler<NamedAttribute> { // Walks name, then value.
+  template <typename T>
+  static void walk(T param, AttrTypeSubElementWalker &walker) {
+    walker.walk(param.getName());
+    walker.walk(param.getValue());
+  }
+  template <typename T>
+  static T replace(T param, AttrSubElementReplacements &attrRepls,
+                   TypeSubElementReplacements &typeRepls) {
+    ArrayRef<Attribute> paramRepls = attrRepls.take_front(2); // NOTE(review): assumes name/value non-null
+    return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
+  }
+};
+/// Implementation for ArrayRef, enabled when has_sub_attr_or_type_v<T> holds.
+template <typename T>
+struct AttrTypeSubElementHandler<ArrayRef<T>,
+                                 std::enable_if_t<has_sub_attr_or_type_v<T>>> {
+  using EltHandler = AttrTypeSubElementHandler<T>;
+
+  static void walk(ArrayRef<T> param, AttrTypeSubElementWalker &walker) {
+    for (const T &subElement : param)
+      EltHandler::walk(subElement, walker);
+  }
+  static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls,
+                      TypeSubElementReplacements &typeRepls) {
+    // Same-size attributes/types can view the replacement storage in place.
+    if constexpr (std::is_base_of_v<Attribute, T> &&
+                  sizeof(T) == sizeof(Attribute)) {
+      ArrayRef<Attribute> attrs = attrRepls.take_front(param.size()); // NOTE(review): assumes no nulls (walk skips them)
+      return ArrayRef<T>((const T *)attrs.data(), attrs.size());
+    } else if constexpr (std::is_base_of_v<Type, T> &&
+                         sizeof(T) == sizeof(Type)) {
+      ArrayRef<Type> types = typeRepls.take_front(param.size()); // NOTE(review): assumes no nulls (walk skips them)
+      return ArrayRef<T>((const T *)types.data(), types.size());
+    } else {
+      // Otherwise, we need to allocate storage for the new elements.
+      SmallVector<T> newElements;
+      for (const T &element : param)
+        newElements.emplace_back(
+            EltHandler::replace(element, attrRepls, typeRepls));
+      return newElements;
+    }
+  }
+};
+/// Implementation for std::tuple; elements are handled left to right.
+template <typename... Ts>
+struct AttrTypeSubElementHandler<
+    std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
+  static void walk(const std::tuple<Ts...> &param,
+                   AttrTypeSubElementWalker &walker) {
+    std::apply(
+        [&](auto &&...params) {
+          (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...); // In order.
+        },
+        param);
+  }
+  static auto replace(const std::tuple<Ts...> &param,
+                      AttrSubElementReplacements &attrRepls,
+                      TypeSubElementReplacements &typeRepls) {
+    return std::apply( // The braced init below evaluates replaces in order.
+        [&](const Ts &...params)
+            -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace(
+                params, attrRepls, typeRepls))...> {
+          return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls,
+                                                         typeRepls)...};
+        },
+        param);
+  }
+};
+
+namespace detail {
+template <typename T> // is_tuple<T>: true iff T is a std::tuple.
+struct is_tuple : public std::false_type {};
+template <typename... Ts>
+struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
+
+/// Underlying implementation of the SubElementInterface walk method: the
+/// storage's getAsKey() result is handed to the AttrTypeSubElementHandler
+/// for the key type, which walks each parameter's sub-elements.
+template <typename T>
+void walkImmediateSubElementsImpl(T derived,
+                                  function_ref<void(Attribute)> walkAttrsFn,
+                                  function_ref<void(Type)> walkTypesFn) {
+  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+
+  // If we don't have any sub-elements, there is nothing to do.
+  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+    return;
+  } else {
+    AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn);
+    AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
+  }
+}
+
+/// Underlying implementation of the SubElementInterface replace method: the
+/// storage key is rebuilt with replacements from `replAttrs`/`replTypes`
+/// and a new instance is obtained through T::Base::get.
+template <typename T>
+T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
+                                  ArrayRef<Type> &replTypes) {
+  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
+
+  // If we don't have any sub-elements, we can just return the original.
+  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
+    return derived;
+
+    // Otherwise, we need to replace any necessary sub-elements.
+  } else {
+    AttrSubElementReplacements attrRepls(replAttrs);
+    TypeSubElementReplacements typeRepls(replTypes);
+    auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
+        key, attrRepls, typeRepls); // May hold references into `key`.
+    if constexpr (is_tuple<decltype(key)>::value) { // Unpack into get().
+      return std::apply(
+          [&](auto &&...params) {
+            return T::Base::get(derived.getContext(),
+                                std::forward<decltype(params)>(params)...);
+          },
+          newKey);
+    } else {
+      return T::Base::get(derived.getContext(), newKey);
+    }
+  }
+}
+} // namespace detail
 } // namespace mlir
 
 /// Include the definitions of the sub elemnt interfaces.

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index 3718b38238c23..abb5afcc93aa1 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -32,7 +32,11 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
         method does not recurse into sub elements.
       }], "void", "walkImmediateSubElements",
       (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
-           "llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
+           "llvm::function_ref<void(mlir::Type)>":$walkTypesFn),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        ::mlir::detail::walkImmediateSubElementsImpl(
+          }] # derivedValue # [{, walkAttrsFn, walkTypesFn);
+      }]
     >,
     InterfaceMethod<
       /*desc=*/[{
@@ -47,10 +51,13 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
         not currently supported by the interface. If an implementing type or
         attribute is mutable, it should return `nullptr` if it has no mechanism
         for replacing sub elements.
-      }], attrOrType, "replaceImmediateSubElements", (ins
-        "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
-        "::llvm::ArrayRef<::mlir::Type>":$replTypes
-      )>,
+      }], attrOrType, "replaceImmediateSubElements",
+      (ins "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
+           "::llvm::ArrayRef<::mlir::Type>":$replTypes),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        return ::mlir::detail::replaceImmediateSubElementsImpl(
+           }] # derivedValue # [{, replAttrs, replTypes);
+      }]>,
   ];
 
   code extraClassDeclaration = [{
@@ -154,6 +161,9 @@ def SubElementAttrInterface
   let description = [{
     An interface used to query and manipulate sub-elements, such as sub-types
     and sub-attributes of a composite attribute.
+    
+    To support the introspection of custom parameters that hold sub-elements,
+    a specialization of the `AttrTypeSubElementHandler` class must be provided.
   }];
 }
 
@@ -168,6 +178,9 @@ def SubElementTypeInterface
   let description = [{
     An interface used to query and manipulate sub-elements, such as sub-types
     and sub-attributes of a composite type.
+    
+    To support the introspection of custom parameters that hold sub-elements,
+    a specialization of the `AttrTypeSubElementHandler` class must be provided.
   }];
 }
 

diff  --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index 5bbab1f994ece..7f65707e98726 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -165,6 +165,23 @@ inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
          std::equal(lhs.begin(), lhs.end(), rhs.begin());
 }
 
+//===----------------------------------------------------------------------===//
+// SubElementInterfaces
+//===----------------------------------------------------------------------===//
+
+/// Enable TypeRange to be introspected for sub-elements, one Type per element.
+template <>
+struct AttrTypeSubElementHandler<TypeRange> {
+  static void walk(TypeRange param, AttrTypeSubElementWalker &walker) {
+    walker.walkRange(param);
+  }
+  static TypeRange replace(TypeRange param,
+                           AttrSubElementReplacements &attrRepls,
+                           TypeSubElementReplacements &typeRepls) {
+    return typeRepls.take_front(param.size()); // NOTE(review): assumes no null types (walkRange skips them)
+  }
+};
+
 } // namespace mlir
 
 namespace llvm {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 7d0c2297736b2..3927e8169c0f4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -64,169 +64,6 @@ bool DITypeAttr::classof(Attribute attr) {
   return llvm::isa<DIBasicTypeAttr, DISubroutineTypeAttr>(attr);
 }
 
-//===----------------------------------------------------------------------===//
-// DICompileUnitAttr
-//===----------------------------------------------------------------------===//
-
-void DICompileUnitAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getFile());
-  walkAttrsFn(getProducer());
-}
-
-Attribute
-DICompileUnitAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                               ArrayRef<Type> replTypes) const {
-  return get(getContext(), getSourceLanguage(), replAttrs[0].cast<DIFileAttr>(),
-             replAttrs[1].cast<StringAttr>(), getIsOptimized(),
-             getEmissionKind());
-}
-
-//===----------------------------------------------------------------------===//
-// DICompositeTypeAttr
-//===----------------------------------------------------------------------===//
-
-void DICompositeTypeAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getName());
-  walkAttrsFn(getFile());
-  walkAttrsFn(getScope());
-  for (DINodeAttr element : getElements())
-    walkAttrsFn(element);
-}
-
-Attribute DICompositeTypeAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  ArrayRef<Attribute> elements = replAttrs.drop_front(3);
-  return get(
-      getContext(), getTag(), replAttrs[0].cast<StringAttr>(),
-      cast_or_null<DIFileAttr>(replAttrs[1]), getLine(),
-      cast_or_null<DIScopeAttr>(replAttrs[2]), getSizeInBits(),
-      getAlignInBits(),
-      ArrayRef<DINodeAttr>(static_cast<const DINodeAttr *>(elements.data()),
-                           elements.size()));
-}
-
-//===----------------------------------------------------------------------===//
-// DIDerivedTypeAttr
-//===----------------------------------------------------------------------===//
-
-void DIDerivedTypeAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getName());
-  walkAttrsFn(getBaseType());
-}
-
-Attribute
-DIDerivedTypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                               ArrayRef<Type> replTypes) const {
-  return get(getContext(), getTag(), replAttrs[0].cast<StringAttr>(),
-             replAttrs[1].cast<DITypeAttr>(), getSizeInBits(), getAlignInBits(),
-             getOffsetInBits());
-}
-
-//===----------------------------------------------------------------------===//
-// DILexicalBlockAttr
-//===----------------------------------------------------------------------===//
-
-void DILexicalBlockAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getScope());
-  walkAttrsFn(getFile());
-}
-
-Attribute DILexicalBlockAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replAttrs[0].cast<DIScopeAttr>(), replAttrs[1].cast<DIFileAttr>(),
-             getLine(), getColumn());
-}
-
-//===----------------------------------------------------------------------===//
-// DILexicalBlockFileAttr
-//===----------------------------------------------------------------------===//
-
-void DILexicalBlockFileAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getScope());
-  walkAttrsFn(getFile());
-}
-
-Attribute DILexicalBlockFileAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replAttrs[0].cast<DIScopeAttr>(), replAttrs[1].cast<DIFileAttr>(),
-             getDescriminator());
-}
-
-//===----------------------------------------------------------------------===//
-// DILocalVariableAttr
-//===----------------------------------------------------------------------===//
-
-void DILocalVariableAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getScope());
-  walkAttrsFn(getName());
-  walkAttrsFn(getFile());
-  walkAttrsFn(getType());
-}
-
-Attribute DILocalVariableAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(getContext(), replAttrs[0].cast<DIScopeAttr>(),
-             replAttrs[1].cast<StringAttr>(), replAttrs[2].cast<DIFileAttr>(),
-             getLine(), getArg(), getAlignInBits(),
-             replAttrs[3].cast<DITypeAttr>());
-}
-
-//===----------------------------------------------------------------------===//
-// DISubprogramAttr
-//===----------------------------------------------------------------------===//
-
-void DISubprogramAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getCompileUnit());
-  walkAttrsFn(getScope());
-  walkAttrsFn(getName());
-  walkAttrsFn(getLinkageName());
-  walkAttrsFn(getFile());
-  walkAttrsFn(getType());
-}
-
-Attribute
-DISubprogramAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                              ArrayRef<Type> replTypes) const {
-  return get(getContext(), replAttrs[0].cast<DICompileUnitAttr>(),
-             replAttrs[1].cast<DIScopeAttr>(), replAttrs[2].cast<StringAttr>(),
-             replAttrs[3].cast<StringAttr>(), replAttrs[4].cast<DIFileAttr>(),
-             getLine(), getScopeLine(), getSubprogramFlags(),
-             replAttrs[5].cast<DISubroutineTypeAttr>());
-}
-
-//===----------------------------------------------------------------------===//
-// DISubroutineTypeAttr
-//===----------------------------------------------------------------------===//
-
-void DISubroutineTypeAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (DITypeAttr type : getTypes())
-    walkAttrsFn(type);
-}
-
-Attribute DISubroutineTypeAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(
-      getContext(), getCallingConvention(),
-      ArrayRef<DITypeAttr>(static_cast<const DITypeAttr *>(replAttrs.data()),
-                           replAttrs.size()));
-}
-
 //===----------------------------------------------------------------------===//
 // LoopOptionsAttrBuilder
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 99fa193185a61..133fc6036931e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -183,20 +183,6 @@ LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
   return dataLayout.getTypePreferredAlignment(getElementType());
 }
 
-//===----------------------------------------------------------------------===//
-// SubElementTypeInterface
-
-void LLVMArrayType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type LLVMArrayType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes.front(), getNumElements());
-}
-
 //===----------------------------------------------------------------------===//
 // Function type.
 //===----------------------------------------------------------------------===//
@@ -247,22 +233,6 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SubElementTypeInterface
-
-void LLVMFunctionType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getReturnType());
-  for (Type type : getParams())
-    walkTypesFn(type);
-}
-
-Type LLVMFunctionType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes.front(), replTypes.drop_front(), isVarArg());
-}
-
 //===----------------------------------------------------------------------===//
 // LLVMPointerType
 //===----------------------------------------------------------------------===//
@@ -439,20 +409,6 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SubElementTypeInterface
-
-void LLVMPointerType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type LLVMPointerType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(getContext(), replTypes.front(), getAddressSpace());
-}
-
 //===----------------------------------------------------------------------===//
 // Struct type.
 //===----------------------------------------------------------------------===//
@@ -749,17 +705,6 @@ LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
       emitError, elementType, numElements);
 }
 
-void LLVMFixedVectorType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type LLVMFixedVectorType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes[0], getNumElements());
-}
-
 //===----------------------------------------------------------------------===//
 // LLVMScalableVectorType.
 //===----------------------------------------------------------------------===//
@@ -792,17 +737,6 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
       emitError, elementType, numElements);
 }
 
-void LLVMScalableVectorType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type LLVMScalableVectorType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes[0], getMinNumElements());
-}
-
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index ed22134d1dcc8..8a3c162f59423 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -43,23 +43,6 @@ void BuiltinDialect::registerAttributes() {
       >();
 }
 
-//===----------------------------------------------------------------------===//
-// ArrayAttr
-//===----------------------------------------------------------------------===//
-
-void ArrayAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (Attribute attr : getValue())
-    walkAttrsFn(attr);
-}
-
-Attribute
-ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                       ArrayRef<Type> replTypes) const {
-  return get(getContext(), replAttrs);
-}
-
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -217,25 +200,6 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
   return Base::get(context, ArrayRef<NamedAttribute>());
 }
 
-void DictionaryAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (const NamedAttribute &attr : getValue())
-    walkAttrsFn(attr.getValue());
-}
-
-Attribute
-DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                            ArrayRef<Type> replTypes) const {
-  std::vector<NamedAttribute> vec = getValue().vec();
-  for (auto &it : llvm::enumerate(replAttrs))
-    vec[it.index()].setValue(it.value());
-
-  // The above only modifies the mapped value, but not the key, and therefore
-  // not the order of the elements. It remains sorted
-  return getWithSorted(getContext(), vec);
-}
-
 //===----------------------------------------------------------------------===//
 // StridedLayoutAttr
 //===----------------------------------------------------------------------===//
@@ -375,24 +339,6 @@ StringAttr SymbolRefAttr::getLeafReference() const {
   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
 }
 
-void SymbolRefAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getRootReference());
-  for (FlatSymbolRefAttr ref : getNestedReferences())
-    walkAttrsFn(ref);
-}
-
-Attribute
-SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                           ArrayRef<Type> replTypes) const {
-  ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
-  ArrayRef<FlatSymbolRefAttr> nestedRefs(
-      static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
-      rawNestedRefs.size());
-  return get(replAttrs[0].cast<StringAttr>(), nestedRefs);
-}
-
 //===----------------------------------------------------------------------===//
 // IntegerAttr
 //===----------------------------------------------------------------------===//
@@ -1812,22 +1758,6 @@ SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TypeAttr
-//===----------------------------------------------------------------------===//
-
-void TypeAttr::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getValue());
-}
-
-Attribute
-TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                      ArrayRef<Type> replTypes) const {
-  return get(replTypes[0]);
-}
-
 //===----------------------------------------------------------------------===//
 // Attribute Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index fe6d6ac3b2c4d..d65c5e9d28b1e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -187,20 +187,6 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
   return clone(newArgTypes, newResultTypes);
 }
 
-void FunctionType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
-    walkTypesFn(type);
-}
-
-Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                               ArrayRef<Type> replTypes) const {
-  unsigned numInputs = getNumInputs();
-  return get(getContext(), replTypes.take_front(numInputs),
-             replTypes.drop_front(numInputs));
-}
-
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
@@ -258,17 +244,6 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
   return VectorType();
 }
 
-void VectorType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                             ArrayRef<Type> replTypes) const {
-  return get(getShape(), replTypes.front(), getNumScalableDims());
-}
-
 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
   return VectorType::get(shape.value_or(getShape()), elementType,
@@ -343,20 +318,6 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
   return checkTensorElementType(emitError, elementType);
 }
 
-void RankedTensorType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-  if (Attribute encoding = getEncoding())
-    walkAttrsFn(encoding);
-}
-
-Type RankedTensorType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(getShape(), replTypes.front(),
-             replAttrs.empty() ? Attribute() : replAttrs.back());
-}
-
 //===----------------------------------------------------------------------===//
 // UnrankedTensorType
 //===----------------------------------------------------------------------===//
@@ -367,17 +328,6 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
   return checkTensorElementType(emitError, elementType);
 }
 
-void UnrankedTensorType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-}
-
-Type UnrankedTensorType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes.front());
-}
-
 //===----------------------------------------------------------------------===//
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
@@ -671,24 +621,6 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-void MemRefType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-  if (!getLayout().isIdentity())
-    walkAttrsFn(getLayout());
-  walkAttrsFn(getMemorySpace());
-}
-
-Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                             ArrayRef<Type> replTypes) const {
-  bool hasLayout = replAttrs.size() > 1;
-  return get(getShape(), replTypes[0],
-             hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
-                       : MemRefLayoutAttrInterface(),
-             hasLayout ? replAttrs[1] : replAttrs[0]);
-}
-
 //===----------------------------------------------------------------------===//
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
@@ -870,18 +802,6 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   return success();
 }
 
-void UnrankedMemRefType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkTypesFn(getElementType());
-  walkAttrsFn(getMemorySpace());
-}
-
-Type UnrankedMemRefType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  return get(replTypes.front(), replAttrs.front());
-}
-
 //===----------------------------------------------------------------------===//
 /// TupleType
 //===----------------------------------------------------------------------===//
@@ -905,18 +825,6 @@ void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
 /// Return the number of element types.
 size_t TupleType::size() const { return getImpl()->size(); }
 
-void TupleType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (Type type : getTypes())
-    walkTypesFn(type);
-}
-
-Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                            ArrayRef<Type> replTypes) const {
-  return get(getContext(), replTypes);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 8a8801daa1160..dcbf9dcecfe29 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -80,20 +80,6 @@ CallSiteLoc CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
   return CallSiteLoc::get(name, caller);
 }
 
-void CallSiteLoc::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getCallee());
-  walkAttrsFn(getCaller());
-}
-
-Attribute
-CallSiteLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                         ArrayRef<Type> replTypes) const {
-  return get(replAttrs[0].cast<LocationAttr>(),
-             replAttrs[1].cast<LocationAttr>());
-}
-
 //===----------------------------------------------------------------------===//
 // FusedLoc
 //===----------------------------------------------------------------------===//
@@ -135,55 +121,3 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
 
   return Base::get(context, locs, metadata);
 }
-
-void FusedLoc::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (Attribute attr : getLocations())
-    walkAttrsFn(attr);
-  walkAttrsFn(getMetadata());
-}
-
-Attribute
-FusedLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                      ArrayRef<Type> replTypes) const {
-  SmallVector<Location> newLocs;
-  newLocs.reserve(replAttrs.size() - 1);
-  for (Attribute attr : replAttrs.drop_back())
-    newLocs.push_back(attr.cast<LocationAttr>());
-  return get(getContext(), newLocs, replAttrs.back());
-}
-
-//===----------------------------------------------------------------------===//
-// NameLoc
-//===----------------------------------------------------------------------===//
-
-void NameLoc::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getName());
-  walkAttrsFn(getChildLoc());
-}
-
-Attribute NameLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                               ArrayRef<Type> replTypes) const {
-  return get(replAttrs[0].cast<StringAttr>(),
-             replAttrs[1].cast<LocationAttr>());
-}
-
-//===----------------------------------------------------------------------===//
-// OpaqueLoc
-//===----------------------------------------------------------------------===//
-
-void OpaqueLoc::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  walkAttrsFn(getFallbackLocation());
-}
-
-Attribute
-OpaqueLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                       ArrayRef<Type> replTypes) const {
-  return get(getUnderlyingLocation(), getUnderlyingTypeID(),
-             replAttrs[0].cast<LocationAttr>());
-}

diff  --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index fd05b9d01eea4..ae0223f0936ef 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -27,11 +27,6 @@ static void walkSubElementsImpl(InterfaceT interface,
                                 DenseSet<Type> &visitedTypes) {
   interface.walkImmediateSubElements(
       [&](Attribute attr) {
-        // Guard against potentially null inputs. This removes the need for the
-        // derived attribute/type to do it.
-        if (!attr)
-          return;
-
         // Avoid infinite recursion when visiting sub attributes later, if this
         // is a mutable attribute.
         if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
@@ -48,11 +43,6 @@ static void walkSubElementsImpl(InterfaceT interface,
         walkAttrsFn(attr);
       },
       [&](Type type) {
-        // Guard against potentially null inputs. This removes the need for the
-        // derived attribute/type to do it.
-        if (!type)
-          return;
-
         // Avoid infinite recursion when visiting sub types later, if this
         // is a mutable type.
         if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
@@ -103,10 +93,6 @@ static void updateSubElementImpl(
     return;
   newElements.push_back(element);
 
-  // Guard against potentially null inputs. We always map null to null.
-  if (!element)
-    return;
-
   // Check for an existing mapping for this element, and walk it if we haven't
   // yet.
   T *mappedElement = &visited[element];

diff  --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 1ae66555715f7..9dc8e6380c795 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -47,6 +47,8 @@ struct IntegerTypeStorage : public TypeStorage {
         IntegerTypeStorage(key.first, key.second);
   }
 
+  KeyTy getAsKey() const { return KeyTy(width, signedness); }
+
   unsigned width : 30;
   IntegerType::SignednessSemantics signedness : 2;
 };
@@ -59,7 +61,7 @@ struct FunctionTypeStorage : public TypeStorage {
         inputsAndResults(inputsAndResults) {}
 
   /// The hash key used for uniquing.
-  using KeyTy = std::pair<TypeRange, TypeRange>;
+  using KeyTy = std::tuple<TypeRange, TypeRange>;
   bool operator==(const KeyTy &key) const {
     if (std::get<0>(key) == getInputs())
       return std::get<1>(key) == getResults();
@@ -69,7 +71,7 @@ struct FunctionTypeStorage : public TypeStorage {
   /// Construction.
   static FunctionTypeStorage *construct(TypeStorageAllocator &allocator,
                                         const KeyTy &key) {
-    TypeRange inputs = key.first, results = key.second;
+    auto [inputs, results] = key;
 
     // Copy the inputs and results into the bump pointer.
     SmallVector<Type, 16> types;
@@ -90,6 +92,8 @@ struct FunctionTypeStorage : public TypeStorage {
     return ArrayRef<Type>(inputsAndResults + numInputs, numResults);
   }
 
+  KeyTy getAsKey() const { return KeyTy(getInputs(), getResults()); }
+
   unsigned numInputs;
   unsigned numResults;
   Type const *inputsAndResults;
@@ -127,6 +131,8 @@ struct TupleTypeStorage final
     return {getTrailingObjects<Type>(), size()};
   }
 
+  KeyTy getAsKey() const { return getTypes(); }
+
   /// The number of tuple elements.
   unsigned numElements;
 };

diff  --git a/mlir/test/Dialect/Affine/loop-tiling.mlir b/mlir/test/Dialect/Affine/loop-tiling.mlir
index b84ffe10867a9..e6c33fd9292fb 100644
--- a/mlir/test/Dialect/Affine/loop-tiling.mlir
+++ b/mlir/test/Dialect/Affine/loop-tiling.mlir
@@ -133,8 +133,8 @@ func.func @tile_with_symbolic_loop_upper_bounds(%arg0: memref<?x?xf32>, %arg1: m
 // CHECK:       memref.dim %{{.*}}, %c0 : memref<?x?xf32>
 // CHECK-NEXT:  affine.for %{{.*}} = 0 to %{{.*}} step 32 {
 // CHECK-NEXT:    affine.for %{{.*}} = 0 to %{{.*}} step 32 {
-// CHECK-NEXT:      affine.for %{{.*}} = #map(%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] {
-// CHECK-NEXT:        affine.for %{{.*}} = #map(%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] {
+// CHECK-NEXT:      affine.for %{{.*}} = #[[$MAP:.*]](%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] {
+// CHECK-NEXT:        affine.for %{{.*}} = #[[$MAP]](%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] {
 // CHECK-NEXT:          affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 // CHECK-NEXT:          affine.for %{{.*}} = 0 to %{{.*}} {
 // CHECK-NEXT:            affine.load

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index c96f95a1b517a..4ff1f19fe36b5 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -775,9 +775,9 @@ func.func @input_stays_same(%arg0 : memref<?x1x?xf32, strided<[?, 1, 1]>>, %arg1
   return %shape : memref<?x1x?x1x?xf32>
 }
 
-// CHECK:     #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
-// CHECK:     #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
-// CHECK:     #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:     #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK-DAG:     #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-DAG:     #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK:     func @input_stays_same(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x1x?xf32, strided<[?, 1, 1]>>,
 // CHECK-SAME:  %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 393f0f49e15f7..ded7374d3ed82 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -416,7 +416,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
 // CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]])
 // CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG3]], %[[TMP1]])
-// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #map2(%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #{{.*}}(%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----

diff  --git a/mlir/test/Dialect/SCF/for-loop-specialization.mlir b/mlir/test/Dialect/SCF/for-loop-specialization.mlir
index 40e8d7dfe4571..ff66c6c1e47f5 100644
--- a/mlir/test/Dialect/SCF/for-loop-specialization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-specialization.mlir
@@ -23,7 +23,7 @@ func.func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>,
 // CHECK:           [[CST_0:%.*]] = arith.constant 0 : index
 // CHECK:           [[CST_1:%.*]] = arith.constant 1 : index
 // CHECK:           [[DIM_0:%.*]] = memref.dim [[ARG1]], [[CST_0]] : memref<?xf32>
-// CHECK:           [[MIN:%.*]] = affine.min #map(){{\[}}[[DIM_0]], [[ARG0]]]
+// CHECK:           [[MIN:%.*]] = affine.min #{{.*}}(){{\[}}[[DIM_0]], [[ARG0]]]
 // CHECK:           [[CST_1024:%.*]] = arith.constant 1024 : index
 // CHECK:           [[PRED:%.*]] = arith.cmpi eq, [[MIN]], [[CST_1024]] : index
 // CHECK:           scf.if [[PRED]] {

diff  --git a/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir b/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir
index f03254405bfee..73c823ca8d55e 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir
@@ -26,8 +26,8 @@ func.func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>
 // CHECK:           [[VAL_7:%.*]] = arith.constant 1 : index
 // CHECK:           [[VAL_8:%.*]] = memref.dim [[VAL_2]], [[VAL_6]] : memref<?x?xf32>
 // CHECK:           [[VAL_9:%.*]] = memref.dim [[VAL_2]], [[VAL_7]] : memref<?x?xf32>
-// CHECK:           [[VAL_10:%.*]] = affine.min #map(){{\[}}[[VAL_8]], [[VAL_0]]]
-// CHECK:           [[VAL_11:%.*]] = affine.min #map1(){{\[}}[[VAL_9]], [[VAL_1]]]
+// CHECK:           [[VAL_10:%.*]] = affine.min #{{.*}}(){{\[}}[[VAL_8]], [[VAL_0]]]
+// CHECK:           [[VAL_11:%.*]] = affine.min #{{.*}}(){{\[}}[[VAL_9]], [[VAL_1]]]
 // CHECK:           [[VAL_12:%.*]] = arith.constant 1024 : index
 // CHECK:           [[VAL_13:%.*]] = arith.cmpi eq, [[VAL_10]], [[VAL_12]] : index
 // CHECK:           [[VAL_14:%.*]] = arith.constant 64 : index

diff  --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
index 897f60b29fdbd..41b0d85b3752e 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
@@ -13,7 +13,7 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK:       #map = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
 // CHECK-LABEL:   func @parallel_loop(
 // CHECK-SAME:                        [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref<?x?xf32>, [[ARG8:%.*]]: memref<?x?xf32>, [[ARG9:%.*]]: memref<?x?xf32>, [[ARG10:%.*]]: memref<?x?xf32>) {
 // CHECK:           [[C0:%.*]] = arith.constant 0 : index
@@ -22,8 +22,8 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
 // CHECK:           [[V1:%.*]] = arith.muli [[ARG5]], [[C1]] : index
 // CHECK:           [[V2:%.*]] = arith.muli [[ARG6]], [[C4]] : index
 // CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) {
-// CHECK:             [[V5:%.*]] = affine.min #map([[V1]], [[ARG3]], [[V3]])
-// CHECK:             [[V6:%.*]] = affine.min #map([[V2]], [[ARG4]], [[V4]])
+// CHECK:             [[V5:%.*]] = affine.min #[[$MAP]]([[V1]], [[ARG3]], [[V3]])
+// CHECK:             [[V6:%.*]] = affine.min #[[$MAP]]([[V2]], [[ARG4]], [[V4]])
 // CHECK:             scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V5]], [[V6]]) step ([[ARG5]], [[ARG6]]) {
 // CHECK:               [[V9:%.*]] = arith.addi [[V7]], [[V3]] : index
 // CHECK:               [[V10:%.*]] = arith.addi [[V8]], [[V4]] : index
@@ -91,7 +91,7 @@ func.func @tile_nested_innermost() {
 // CHECK:             [[V3:%.*]] = arith.muli [[C1]], [[C1_1]] : index
 // CHECK:             [[V4:%.*]] = arith.muli [[C1]], [[C4]] : index
 // CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) {
-// CHECK:               [[V7:%.*]] = affine.min #map([[V4]], [[C2]], [[V6]])
+// CHECK:               [[V7:%.*]] = affine.min #{{.*}}([[V4]], [[C2]], [[V6]])
 // CHECK:               scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V7]]) step ([[C1]], [[C1]]) {
 // CHECK:                 = arith.addi [[V8]], [[V5]] : index
 // CHECK:                 = arith.addi [[V9]], [[V6]] : index
@@ -104,7 +104,7 @@ func.func @tile_nested_innermost() {
 // CHECK:           [[V10:%.*]] = arith.muli [[C1]], [[C1_2]] : index
 // CHECK:           [[V11:%.*]] = arith.muli [[C1]], [[C4_1]] : index
 // CHECK:           scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) {
-// CHECK:             [[V14:%.*]] = affine.min #map([[V11]], [[C2]], [[V13]])
+// CHECK:             [[V14:%.*]] = affine.min #{{.*}}([[V11]], [[C2]], [[V13]])
 // CHECK:             scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V14]]) step ([[C1]], [[C1]]) {
 // CHECK:               = arith.addi [[V15]], [[V12]] : index
 // CHECK:               = arith.addi [[V16]], [[V13]] : index

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 628ce3b4535a5..7f850ccbbc4e2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 

diff  --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 56a105630f9d2..977aec2536b1e 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -1,10 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
 
 // Identity maps used in trivial compositions in MemRefs are optimized away.
-// CHECK-NOT: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)>
 #map0 = affine_map<(i, j) -> (i, j)>
-
-// CHECK-NOT: #map{{[0-9]*}} = affine_map<(d0, d1)[s0] -> (d0, d1)>
 #map1 = affine_map<(i, j)[s0] -> (i, j)>
 
 // CHECK: #map{{[0-9]*}} = affine_map<() -> (0)>
@@ -194,7 +191,6 @@
 
 // Check if parser can parse affine_map with identifiers that collide with
 // integer types.
-// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)>
 #map60 = affine_map<(i0, i1) -> (i0, i1)>
 
 // Check if parser can parse affine_map with identifiers that collide with

diff  --git a/mlir/test/IR/memory-ops.mlir b/mlir/test/IR/memory-ops.mlir
index fbbf36d6bc210..c1cfc3bfa0dbf 100644
--- a/mlir/test/IR/memory-ops.mlir
+++ b/mlir/test/IR/memory-ops.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s | FileCheck %s
 
-// CHECK: #map = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>
 
 // CHECK-LABEL: func @alloc() {
 func.func @alloc() {
@@ -17,11 +17,11 @@ func.func @alloc() {
   %1 = memref.alloc(%c0, %c1) : memref<?x?xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
 
   // Test alloc with no dynamic dimensions and one symbol.
-  // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #map, 1>
+  // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1>
   %2 = memref.alloc()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
 
   // Test alloc with dynamic dimensions and one symbol.
-  // CHECK: %{{.*}} = memref.alloc(%{{.*}})[%{{.*}}] : memref<2x?xf32, #map, 1>
+  // CHECK: %{{.*}} = memref.alloc(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1>
   %3 = memref.alloc(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>
 
   // Alloc with no mappings.
@@ -48,11 +48,11 @@ func.func @alloca() {
   %1 = memref.alloca(%c0, %c1) : memref<?x?xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
 
   // Test alloca with no dynamic dimensions and one symbol.
-  // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #map, 1>
+  // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1>
   %2 = memref.alloca()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
 
   // Test alloca with dynamic dimensions and one symbol.
-  // CHECK: %{{.*}} = memref.alloca(%{{.*}})[%{{.*}}] : memref<2x?xf32, #map, 1>
+  // CHECK: %{{.*}} = memref.alloca(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1>
   %3 = memref.alloca(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>
 
   // Alloca with no mappings, but with alignment.

diff  --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir
index 729e1dc2d9e80..c1fded7a16bb9 100644
--- a/mlir/test/Transforms/loop-fusion-2.mlir
+++ b/mlir/test/Transforms/loop-fusion-2.mlir
@@ -508,16 +508,16 @@ func.func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<1
   }
   return
 }
-// MAXIMAL:      #map = affine_map<(d0, d1) -> (d0 * 16 + d1)>
+// MAXIMAL:      #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
 // MAXIMAL-LABEL: func @fuse_across_dim_mismatch
 // MAXIMAL:        memref.alloc() : memref<1x1xf32>
 // MAXIMAL:        affine.for %{{.*}} = 0 to 9 {
 // MAXIMAL-NEXT:    affine.for %{{.*}} = 0 to 9 {
 // MAXIMAL-NEXT:      affine.for %{{.*}} = 0 to 4 {
 // MAXIMAL-NEXT:        affine.for %{{.*}} = 0 to 16 {
-// MAXIMAL-NEXT:          affine.apply #map(%{{.*}}, %{{.*}})
+// MAXIMAL-NEXT:          affine.apply #[[$MAP]](%{{.*}}, %{{.*}})
 // MAXIMAL-NEXT:          affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
-// MAXIMAL-NEXT:          affine.apply #map(%{{.*}}, %{{.*}})
+// MAXIMAL-NEXT:          affine.apply #[[$MAP]](%{{.*}}, %{{.*}})
 // MAXIMAL-NEXT:          affine.load %{{.*}}[0, 0] : memref<1x1xf32>
 // MAXIMAL-NEXT:        }
 // MAXIMAL-NEXT:      }

diff  --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index b45b62a92e4a6..34420c50a51ab 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -29,15 +29,15 @@ func.func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
 // Same test with op_nonnorm, with maps in the arguments and the operations in the function.
 
 // CHECK-LABEL: test_nonnorm
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #map>)
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #[[MAP:.*]]>)
 func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
     %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
     "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
     memref.dealloc %0 :  memref<1x16x14x14xf32, #map0>
 
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #map>
-    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map>, memref<1x16x14x14xf32, #map>) -> ()
-    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #map>
+    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[MAP]]>
+    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #[[MAP]]>, memref<1x16x14x14xf32, #[[MAP]]>) -> ()
+    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #[[MAP]]>
     return
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 07cfca121f62d..0c35f81c129b0 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -119,8 +119,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
 }
 
 def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface,
-        ["replaceImmediateSubElements"]>
+    SubElementAttrInterface
   ]> {
   let mnemonic = "sub_elements_access";
 

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 28fde0987ac09..4c7639b3ae252 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -150,20 +150,6 @@ void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
           << ">";
 }
 
-void TestSubElementsAccessAttr::walkImmediateSubElements(
-    llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
-    llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
-  walkAttrsFn(getFirst());
-  walkAttrsFn(getSecond());
-  walkAttrsFn(getThird());
-}
-
-Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  assert(replAttrs.size() == 3 && "invalid number of replacement attributes");
-  return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]);
-}
-
 //===----------------------------------------------------------------------===//
 // TestExtern1DI64ElementsAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index ec0e79ebdd834..f34aaa364fac9 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -457,6 +457,13 @@ void DefGen::emitKeyType() {
                         [&](auto &param) { os << param.getCppType(); });
   os << '>';
   storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
+
+  // Add a method to construct the key type from the storage.
+  Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
+  m->body().indent() << "return KeyTy(";
+  llvm::interleaveComma(params, m->body().indent(),
+                        [&](auto &param) { m->body() << param.getName(); });
+  m->body() << ");";
 }
 
 void DefGen::emitEquals() {

diff  --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp
index 292628aad5d47..66e29d48f7f47 100644
--- a/mlir/unittests/IR/SubElementInterfaceTest.cpp
+++ b/mlir/unittests/IR/SubElementInterfaceTest.cpp
@@ -23,13 +23,14 @@ TEST(SubElementInterfaceTest, Nested) {
   BoolAttr trueAttr = builder.getBoolAttr(true);
   BoolAttr falseAttr = builder.getBoolAttr(false);
   ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
+  StringAttr strAttr = builder.getStringAttr("array");
   DictionaryAttr dictAttr =
-      builder.getDictionaryAttr(builder.getNamedAttr("array", boolArrayAttr));
+      builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
 
   SmallVector<Attribute> subAttrs;
   dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); });
   EXPECT_EQ(llvm::makeArrayRef(subAttrs),
-            ArrayRef<Attribute>({trueAttr, falseAttr, boolArrayAttr}));
+            ArrayRef<Attribute>({strAttr, trueAttr, falseAttr, boolArrayAttr}));
 }
 
 } // namespace


        


More information about the flang-commits mailing list