[Mlir-commits] [mlir] 01eedbc - [mlir] Refactor SubElementInterface replace support

River Riddle llvmlistbot at llvm.org
Tue Jul 26 14:51:50 PDT 2022


Author: River Riddle
Date: 2022-07-26T14:51:22-07:00
New Revision: 01eedbc7c14859c273bbd98693c67f35c59e8d85

URL: https://github.com/llvm/llvm-project/commit/01eedbc7c14859c273bbd98693c67f35c59e8d85
DIFF: https://github.com/llvm/llvm-project/commit/01eedbc7c14859c273bbd98693c67f35c59e8d85.diff

LOG: [mlir] Refactor SubElementInterface replace support

The current support was essentially the amount necessary
to support replacing SymbolRefAttrs, but suffers from various
deficiencies (both ergonomic and functional):

* Replacement crashes if unsupported
  This makes it really hard to use safely, given that you don't know
  whether or not it is going to crash when you use it.

* Types aren't supported
  This seems like a simple missed addition when the attribute replacement
  support was originally added.

* The ergonomics are weird
  It currently uses an index-based replacement, which makes the
  implementations quite clunky.

This commit refactors the support to be a bit more ergonomic, and also
adds support for types in the process. This was also a great opportunity
to greatly simplify how replacement is done in the symbol table.

Fixes #56355

Differential Revision: https://reviews.llvm.org/D130589

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestTypes.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index e415061768fe4..ce69a2e583aeb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -116,6 +116,8 @@ class LLVMArrayType
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -177,6 +179,8 @@ class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -244,6 +248,8 @@ class LLVMPointerType
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -375,6 +381,8 @@ class LLVMStructType
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -408,7 +416,7 @@ class LLVMFixedVectorType
   Type getElementType() const;
 
   /// Returns the number of elements in the fixed vector.
-  unsigned getNumElements();
+  unsigned getNumElements() const;
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
@@ -416,6 +424,8 @@ class LLVMFixedVectorType
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -450,7 +460,7 @@ class LLVMScalableVectorType
   /// Returns the scaling factor of the number of elements in the vector. The
   /// vector contains at least the resulting number of elements, or any non-zero
   /// multiple of this number.
-  unsigned getMinNumElements();
+  unsigned getMinNumElements() const;
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
@@ -458,6 +468,8 @@ class LLVMScalableVectorType
 
   void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
                                 function_ref<void(Type)> walkTypesFn) const;
+  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 503c0209eafa1..67fa0a31b5670 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -72,8 +72,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface,
-        ["replaceImmediateSubAttribute"]>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -425,8 +424,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface,
-        ["replaceImmediateSubAttribute"]>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
@@ -1046,7 +1044,9 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
+def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+  ]> {
   let summary = "An Attribute containing a symbolic reference to an Operation";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index 9ee90513cdda3..f1aede639cf09 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -21,7 +21,8 @@ include "mlir/IR/OpBase.td"
 // SubElementInterfaceBase
 //===----------------------------------------------------------------------===//
 
-class SubElementInterfaceBase<string interfaceName, string derivedValue> {
+class SubElementInterfaceBase<string interfaceName, string attrOrType,
+                              string derivedValue> {
   string cppNamespace = "::mlir";
 
   list<InterfaceMethod> methods = [
@@ -35,52 +36,78 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Replace the attributes identified by the indices with the corresponding
-        value. The index is derived from the order of the attributes returned by
-        the attribute callback of `walkImmediateSubElements`. An index of 0 would
-        replace the very first attribute given by `walkImmediateSubElements`.
-        The new instance with the values replaced is returned.
-      }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
-      (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
-      [{}],
-      /*defaultImplementation=*/[{
-        llvm_unreachable("Attribute or Type does not support replacing attributes");
-      }]
-    >,
+        Replace the immediately nested sub-attributes and sub-types with those provided.
+        The order of the provided elements is derived from the order of the elements
+        returned by the callbacks of `walkImmediateSubElements`. The element at index 0
+        would replace the very first attribute given by `walkImmediateSubElements`.
+        On success, the new instance with the values replaced is returned. If replacement
+        fails, nullptr is returned.
+      }], attrOrType, "replaceImmediateSubElements", (ins
+        "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
+        "::llvm::ArrayRef<::mlir::Type>":$replTypes
+      )>,
   ];
 
   code extraClassDeclaration = [{
-    /// Walk all of the held sub-attributes.
-    void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
-      walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
-    }
-
-    /// Walk all of the held sub-types.
-    void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
-      walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
-    }
-
     /// Walk all of the held sub-attributes and sub-types.
     void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
                          llvm::function_ref<void(mlir::Type)> walkTypesFn);
-  }];
 
+    /// Recursively replace all of the nested sub-attributes and sub-types using the
+    /// provided map functions. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceSubElements(
+      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
+      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn
+    );
+  }];
   code extraTraitClassDeclaration = [{
+    /// Walk all of the held sub-attributes and sub-types.
+    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+                         llvm::function_ref<void(mlir::Type)> walkTypesFn) {
+      }] # interfaceName # " interface(" # derivedValue # [{);
+      interface.walkSubElements(walkAttrsFn, walkTypesFn);
+    }
+
+    /// Recursively replace all of the nested sub-attributes and sub-types using the
+    /// provided map functions. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceSubElements(
+      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
+      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+      }] # interfaceName # " interface(" # derivedValue # [{);
+      return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
+    }
+
+    /// Recursively replace all of the nested sub-attributes and sub-types using the
+    /// provided map functions. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceImmediateSubElements(
+      llvm::ArrayRef<mlir::Attribute> replAttrs,
+      llvm::function_ref<mlir::Type(mlir::Type)> replTypes) {
+      return nullptr;
+    }
+  }];
+  code extraSharedClassDeclaration = [{
     /// Walk all of the held sub-attributes.
     void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
       walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
     }
-
     /// Walk all of the held sub-types.
     void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
       walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
     }
-
-    /// Walk all of the held sub-attributes and sub-types.
-    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
-                         llvm::function_ref<void(mlir::Type)> walkTypesFn) {
-      }] # interfaceName # " interface(" # derivedValue # [{);
-      interface.walkSubElements(walkAttrsFn, walkTypesFn);
+    
+    /// Recursively replace all of the nested sub-attributes using the provided
+    /// map function. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceSubElements(
+      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn) {
+      return replaceSubElements(
+        replaceAttrFn, [](mlir::Type type) { return type; });
+    }
+    /// Recursively replace all of the nested sub-types using the provided map
+    /// function. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceSubElements(
+      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+      return replaceSubElements(
+        [](mlir::Attribute attr) { return attr; }, replaceTypeFn);
     }
   }];
 }
@@ -91,7 +118,8 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
 
 def SubElementAttrInterface
     : AttrInterface<"SubElementAttrInterface">,
-      SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> {
+      SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute",
+                              "$_attr"> {
   let description = [{
     An interface used to query and manipulate sub-elements, such as sub-types
     and sub-attributes of a composite attribute.
@@ -104,7 +132,8 @@ def SubElementAttrInterface
 
 def SubElementTypeInterface
     : TypeInterface<"SubElementTypeInterface">,
-      SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> {
+      SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type",
+                              "$_type"> {
   let description = [{
     An interface used to query and manipulate sub-elements, such as sub-types
     and sub-attributes of a composite type.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 49d2d8d24963b..3f39a6e07748b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -92,6 +92,11 @@ void LLVMArrayType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type LLVMArrayType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes.front(), getNumElements());
+}
+
 //===----------------------------------------------------------------------===//
 // Function type.
 //===----------------------------------------------------------------------===//
@@ -166,6 +171,11 @@ void LLVMFunctionType::walkImmediateSubElements(
     walkTypesFn(type);
 }
 
+Type LLVMFunctionType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes.front(), replTypes.drop_front(), isVarArg());
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer type.
 //===----------------------------------------------------------------------===//
@@ -374,6 +384,11 @@ void LLVMPointerType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type LLVMPointerType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes.front(), getAddressSpace());
+}
+
 //===----------------------------------------------------------------------===//
 // Struct type.
 //===----------------------------------------------------------------------===//
@@ -617,6 +632,13 @@ void LLVMStructType::walkImmediateSubElements(
     walkTypesFn(type);
 }
 
+Type LLVMStructType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  // TODO: It's not clear how we support replacing sub-elements of mutable
+  // types.
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Vector types.
 //===----------------------------------------------------------------------===//
@@ -653,7 +675,7 @@ Type LLVMFixedVectorType::getElementType() const {
   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
 }
 
-unsigned LLVMFixedVectorType::getNumElements() {
+unsigned LLVMFixedVectorType::getNumElements() const {
   return getImpl()->numElements;
 }
 
@@ -674,6 +696,11 @@ void LLVMFixedVectorType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type LLVMFixedVectorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes[0], getNumElements());
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMScalableVectorType.
 //===----------------------------------------------------------------------===//
@@ -696,7 +723,7 @@ Type LLVMScalableVectorType::getElementType() const {
   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
 }
 
-unsigned LLVMScalableVectorType::getMinNumElements() {
+unsigned LLVMScalableVectorType::getMinNumElements() const {
   return getImpl()->numElements;
 }
 
@@ -720,6 +747,11 @@ void LLVMScalableVectorType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type LLVMScalableVectorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes[0], getMinNumElements());
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 218622ba9d73b..5b282bf868b59 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -54,13 +54,10 @@ void ArrayAttr::walkImmediateSubElements(
     walkAttrsFn(attr);
 }
 
-SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
-    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
-  std::vector<Attribute> vector = getValue().vec();
-  for (auto &it : replacements) {
-    vector[it.first] = it.second;
-  }
-  return get(getContext(), vector);
+Attribute
+ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                       ArrayRef<Type> replTypes) const {
+  return get(getContext(), replAttrs);
 }
 
 //===----------------------------------------------------------------------===//
@@ -227,11 +224,12 @@ void DictionaryAttr::walkImmediateSubElements(
     walkAttrsFn(attr.getValue());
 }
 
-SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
-    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+Attribute
+DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                            ArrayRef<Type> replTypes) const {
   std::vector<NamedAttribute> vec = getValue().vec();
-  for (auto &it : replacements)
-    vec[it.first].setValue(it.second);
+  for (auto &it : llvm::enumerate(replAttrs))
+    vec[it.index()].setValue(it.value());
 
   // The above only modifies the mapped value, but not the key, and therefore
   // not the order of the elements. It remains sorted
@@ -326,6 +324,32 @@ StringAttr SymbolRefAttr::getLeafReference() const {
   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
 }
 
+void SymbolRefAttr::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkAttrsFn(getRootReference());
+  for (FlatSymbolRefAttr ref : getNestedReferences())
+    walkAttrsFn(ref);
+}
+
+Attribute
+SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                           ArrayRef<Type> replTypes) const {
+  // Fail (return nullptr) rather than assert if the new root is not a
+  // StringAttr or a new nested reference is not a FlatSymbolRefAttr.
+  auto rootRef = replAttrs.front().dyn_cast_or_null<StringAttr>();
+  ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
+  auto isFlatRef = [](Attribute attr) {
+    return attr && attr.isa<FlatSymbolRefAttr>();
+  };
+  if (!rootRef || !llvm::all_of(rawNestedRefs, isFlatRef))
+    return nullptr;
+  ArrayRef<FlatSymbolRefAttr> nestedRefs(
+      static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
+      rawNestedRefs.size());
+  return get(rootRef, nestedRefs);
+}
+
 //===----------------------------------------------------------------------===//
 // IntegerAttr
 //===----------------------------------------------------------------------===//
@@ -1711,3 +1727,9 @@ void TypeAttr::walkImmediateSubElements(
     function_ref<void(Type)> walkTypesFn) const {
   walkTypesFn(getValue());
 }
+
+Attribute
+TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                      ArrayRef<Type> replTypes) const {
+  return get(replTypes[0]);
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 0d8ede99c92e6..5459841c8391a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -193,6 +193,13 @@ void FunctionType::walkImmediateSubElements(
     walkTypesFn(type);
 }
 
+Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                               ArrayRef<Type> replTypes) const {
+  unsigned numInputs = getNumInputs();
+  return get(getContext(), replTypes.take_front(numInputs),
+             replTypes.drop_front(numInputs));
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
@@ -256,6 +263,11 @@ void VectorType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                             ArrayRef<Type> replTypes) const {
+  return get(getShape(), replTypes.front(), getNumScalableDims());
+}
+
 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
   return VectorType::get(shape.value_or(getShape()), elementType,
@@ -338,6 +350,12 @@ void RankedTensorType::walkImmediateSubElements(
     walkAttrsFn(encoding);
 }
 
+Type RankedTensorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(getShape(), replTypes.front(),
+             replAttrs.empty() ? Attribute() : replAttrs.back());
+}
+
 //===----------------------------------------------------------------------===//
 // UnrankedTensorType
 //===----------------------------------------------------------------------===//
@@ -354,6 +372,11 @@ void UnrankedTensorType::walkImmediateSubElements(
   walkTypesFn(getElementType());
 }
 
+Type UnrankedTensorType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes.front());
+}
+
 //===----------------------------------------------------------------------===//
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
@@ -663,6 +686,15 @@ void MemRefType::walkImmediateSubElements(
   walkAttrsFn(getMemorySpace());
 }
 
+Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                             ArrayRef<Type> replTypes) const {
+  bool hasLayout = replAttrs.size() > 1;
+  return get(getShape(), replTypes[0],
+             hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
+                       : MemRefLayoutAttrInterface(),
+             hasLayout ? replAttrs[1] : replAttrs[0]);
+}
+
 //===----------------------------------------------------------------------===//
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
@@ -829,6 +861,11 @@ void UnrankedMemRefType::walkImmediateSubElements(
   walkAttrsFn(getMemorySpace());
 }
 
+Type UnrankedMemRefType::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  return get(replTypes.front(), replAttrs.front());
+}
+
 //===----------------------------------------------------------------------===//
 /// TupleType
 //===----------------------------------------------------------------------===//
@@ -859,6 +896,11 @@ void TupleType::walkImmediateSubElements(
     walkTypesFn(type);
 }
 
+Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
+                                            ArrayRef<Type> replTypes) const {
+  return get(getContext(), replTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index 4059b99b5db24..f8d47083f11c4 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -12,6 +12,13 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// SubElementInterface
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// WalkSubElements
+
 template <typename InterfaceT>
 static void walkSubElementsImpl(InterfaceT interface,
                                 function_ref<void(Attribute)> walkAttrsFn,
@@ -83,6 +90,121 @@ void SubElementTypeInterface::walkSubElements(
                       visitedTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceSubElements
+
+/// Return true if the given attribute or type has the `IsMutable` trait.
+static bool isMutable(Attribute attr) {
+  return attr.hasTrait<AttributeTrait::IsMutable>();
+}
+static bool isMutable(Type type) {
+  return type.hasTrait<TypeTrait::IsMutable>();
+}
+
+template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
+static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
+                                 DenseMap<T, T> &visited,
+                                 SmallVectorImpl<T> &newElements,
+                                 FailureOr<bool> &changed,
+                                 ReplaceSubElementFnT &&replaceSubElementFn) {
+  // Bail early if we failed at any point.
+  if (failed(changed))
+    return;
+  newElements.push_back(element);
+
+  // Guard against potentially null inputs. We always map null to null.
+  if (!element)
+    return;
+
+  // Check for an existing mapping for this element, and walk it if we haven't
+  // yet. Read it by value: the recursion below may insert into `visited`.
+  T mappedElement = visited.lookup(element);
+  if (!mappedElement) {
+    // Map the element to itself while in progress so that cycles terminate.
+    visited[element] = element;
+    if (!(mappedElement = walkFn(element))) {
+      changed = failure();
+      return;
+    }
+    // Handle replacing sub-elements if this element is also a container.
+    if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
+      if (!(mappedElement = replaceSubElementFn(interface))) {
+        changed = failure();
+        return;
+      }
+    }
+    visited[element] = mappedElement;
+  }
+  // Update to the mapped element.
+  if (mappedElement != element) {
+    newElements.back() = mappedElement;
+    changed = true;
+  }
+}
+
+template <typename InterfaceT>
+static typename InterfaceT::ValueType
+replaceSubElementsImpl(InterfaceT interface,
+                       function_ref<Attribute(Attribute)> walkAttrsFn,
+                       function_ref<Type(Type)> walkTypesFn,
+                       DenseMap<Attribute, Attribute> &visitedAttrs,
+                       DenseMap<Type, Type> &visitedTypes) {
+  // Walk the current sub-elements, replacing them as necessary.
+  SmallVector<Attribute, 16> newAttrs;
+  SmallVector<Type, 16> newTypes;
+  FailureOr<bool> changed = false;
+  auto replaceSubElementFn = [&](auto subInterface) {
+    return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn,
+                                  visitedAttrs, visitedTypes);
+  };
+  interface.walkImmediateSubElements(
+      [&](Attribute element) {
+        updateSubElementImpl<SubElementAttrInterface>(
+            element, walkAttrsFn, visitedAttrs, newAttrs, changed,
+            replaceSubElementFn);
+      },
+      [&](Type element) {
+        updateSubElementImpl<SubElementTypeInterface>(
+            element, walkTypesFn, visitedTypes, newTypes, changed,
+            replaceSubElementFn);
+      });
+  if (failed(changed))
+    return {};
+
+  // If the sub-elements didn't change, just return the original value.
+  if (!*changed)
+    return interface;
+
+  // If this element is mutable, we don't support changing its sub elements, the
+  // sub element walk doesn't give us a valid ordering for what we need here. If
+  // we want to support mutable elements, we'll need something more.
+  if (isMutable(interface))
+    return {};
+
+  // Use the new elements during the replacement.
+  return interface.replaceImmediateSubElements(newAttrs, newTypes);
+}
+
+Attribute SubElementAttrInterface::replaceSubElements(
+    function_ref<Attribute(Attribute)> replaceAttrFn,
+    function_ref<Type(Type)> replaceTypeFn) {
+  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
+  DenseMap<Attribute, Attribute> visitedAttrs;
+  DenseMap<Type, Type> visitedTypes;
+  return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
+                                visitedAttrs, visitedTypes);
+}
+
+Type SubElementTypeInterface::replaceSubElements(
+    function_ref<Attribute(Attribute)> replaceAttrFn,
+    function_ref<Type(Type)> replaceTypeFn) {
+  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
+  DenseMap<Attribute, Attribute> visitedAttrs;
+  DenseMap<Type, Type> visitedTypes;
+  return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
+                                visitedAttrs, visitedTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // SubElementInterface Tablegen definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index b55487c84d910..fb56d91f68a6c 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -97,6 +97,18 @@ walkSymbolTable(MutableArrayRef<Region> regions,
   return WalkResult::advance();
 }
 
+/// Walk `op` and the operations nested under it, without traversing into any
+/// nested symbol tables (if `op` itself is one, its regions are not walked).
+/// Stops and returns the callback result if it is not `WalkResult::advance()`.
+static Optional<WalkResult>
+walkSymbolTable(Operation *op,
+                function_ref<Optional<WalkResult>(Operation *)> callback) {
+  Optional<WalkResult> result = callback(op);
+  if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
+    return result;
+  return walkSymbolTable(op->getRegions(), callback);
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolTable
 //===----------------------------------------------------------------------===//
@@ -465,21 +477,11 @@ LogicalResult detail::verifySymbol(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// Walk all of the symbol references within the given operation, invoking the
-/// provided callback for each found use. The callbacks takes as arguments: the
-/// use of the symbol, and the nested access chain to the attribute within the
-/// operation dictionary. An access chain is a set of indices into nested
-/// container attributes. For example, a symbol use in an attribute dictionary
-/// that looks like the following:
-///
-///    {use = [{other_attr, @symbol}]}
-///
-/// May have the following access chain:
-///
-///     [0, 0, 1]
-///
-static WalkResult walkSymbolRefs(
-    Operation *op,
-    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+/// provided callback for each found use. The callback takes the use of the
+/// symbol.
+static WalkResult
+walkSymbolRefs(Operation *op,
+               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
   // Check to see if the operation has any attributes.
   DictionaryAttr attrDict = op->getAttrDictionary();
   if (attrDict.empty())
@@ -507,20 +509,19 @@ static WalkResult walkSymbolRefs(
                           WorklistItem &worklistItem) -> WalkResult {
     for (Attribute attr :
          llvm::drop_begin(worklistItem.immediateSubElements, index)) {
-      /// Check for a nested container attribute, these will also need to be
-      /// walked.
-      if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
-        attrWorklist.emplace_back(interface);
-        curAccessChain.push_back(-1);
-        return WalkResult::advance();
-      }
-
       // Invoke the provided callback if we find a symbol use and check for a
       // requested interrupt.
-      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>())
-        if (callback({op, symbolRef}, curAccessChain).wasInterrupted())
+      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
+        if (callback({op, symbolRef}).wasInterrupted())
           return WalkResult::interrupt();
 
+        /// Check for a nested container attribute, these will also need to be
+        /// walked.
+      } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+        attrWorklist.emplace_back(interface);
+        curAccessChain.push_back(-1);
+        return WalkResult::advance();
+      }
       // Make sure to keep the index counter in sync.
       ++index;
     }
@@ -546,9 +547,9 @@ static WalkResult walkSymbolRefs(
 /// Walk all of the uses, for any symbol, that are nested within the given
 /// regions, invoking the provided callback for each. This does not traverse
 /// into any nested symbol tables.
-static Optional<WalkResult> walkSymbolUses(
-    MutableArrayRef<Region> regions,
-    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+static Optional<WalkResult>
+walkSymbolUses(MutableArrayRef<Region> regions,
+               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
   return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
     // Check that this isn't a potentially unknown symbol table.
     if (isPotentiallyUnknownSymbolTable(op))
@@ -560,9 +561,9 @@ static Optional<WalkResult> walkSymbolUses(
 /// Walk all of the uses, for any symbol, that are nested within the given
 /// operation 'from', invoking the provided callback for each. This does not
 /// traverse into any nested symbol tables.
-static Optional<WalkResult> walkSymbolUses(
-    Operation *from,
-    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+static Optional<WalkResult>
+walkSymbolUses(Operation *from,
+               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
   // If this operation has regions, and it, as well as its dialect, isn't
   // registered then conservatively fail. The operation may define a
   // symbol table, so we can't opaquely know if we should traverse to find
@@ -608,11 +609,20 @@ struct SymbolScope {
                 typename llvm::function_traits<CallbackT>::result_t,
                 void>::value> * = nullptr>
   Optional<WalkResult> walk(CallbackT cback) {
-    return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
+    return walk([=](SymbolTable::SymbolUse use) {
       return cback(use), WalkResult::advance();
     });
   }
 
+  /// Walk all of the operations nested under the current scope without
+  /// traversing into any nested symbol tables.
+  template <typename CallbackT>
+  Optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
+    if (Region *region = limit.dyn_cast<Region *>())
+      return ::walkSymbolTable(*region, cback);
+    return ::walkSymbolTable(limit.get<Operation *>(), cback);
+  }
+
   /// The representation of the symbol within this scope.
   SymbolRefAttr symbol;
 
@@ -723,7 +733,7 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
 template <typename FromT>
 static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
   std::vector<SymbolTable::SymbolUse> uses;
-  auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+  auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
     uses.push_back(symbolUse);
     return WalkResult::advance();
   };
@@ -792,7 +802,7 @@ template <typename SymbolT, typename IRUnitT>
 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
     // Walk all of the symbol uses looking for a reference to 'symbol'.
-    if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+    if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
           return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
                      ? WalkResult::interrupt()
                      : WalkResult::advance();
@@ -822,50 +832,6 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
 //===----------------------------------------------------------------------===//
 // SymbolTable::replaceAllSymbolUses
 
-/// Rebuild the given attribute container after replacing all references to a
-/// symbol with the updated attribute in 'accesses'.
-static SubElementAttrInterface rebuildAttrAfterRAUW(
-    SubElementAttrInterface container,
-    ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
-    unsigned depth) {
-  // Given a range of Attributes, update the ones referred to by the given
-  // access chains to point to the new symbol attribute.
-
-  SmallVector<std::pair<size_t, Attribute>> replacements;
-
-  SmallVector<Attribute> subElements;
-  container.walkImmediateSubElements(
-      [&](Attribute attribute) { subElements.push_back(attribute); },
-      [](Type) {});
-  for (unsigned i = 0, e = accesses.size(); i != e;) {
-    ArrayRef<int> access = accesses[i].first;
-
-    // Check to see if this is a leaf access, i.e. a SymbolRef.
-    if (access.size() == depth + 1) {
-      replacements.emplace_back(access.back(), accesses[i].second);
-      ++i;
-      continue;
-    }
-
-    // Otherwise, this is a container. Collect all of the accesses for this
-    // index and recurse. The recursion here is bounded by the size of the
-    // largest access array.
-    auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
-      ArrayRef<int> nextAccess = it.first;
-      return nextAccess.size() > depth + 1 &&
-             nextAccess[depth] == access[depth];
-    });
-    auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
-                                       nestedAccesses, depth + 1);
-    replacements.emplace_back(access[depth], result);
-
-    // Skip over all of the accesses that refer to the nested container.
-    i += nestedAccesses.size();
-  }
-
-  return container.replaceImmediateSubAttribute(replacements);
-}
-
 /// Generates a new symbol reference attribute with a new leaf reference.
 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
                                         FlatSymbolRefAttr newLeafAttr) {
@@ -880,77 +846,47 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
 template <typename SymbolT, typename IRUnitT>
 static LogicalResult
 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
-  // A collection of operations along with their new attribute dictionary.
-  std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
-
-  // The current operation being processed.
-  Operation *curOp = nullptr;
-
-  // The set of access chains into the attribute dictionary of the current
-  // operation, as well as the replacement attribute to use.
-  SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains;
-
-  // Generate a new attribute dictionary for the current operation by replacing
-  // references to the old symbol.
-  auto generateNewAttrDict = [&] {
-    auto oldDict = curOp->getAttrDictionary();
-    auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
-    return newDict.cast<DictionaryAttr>();
-  };
-
   // Generate a new attribute to replace the given attribute.
   FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
+    SymbolRefAttr oldAttr = scope.symbol;
     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
-    auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
-                      ArrayRef<int> accessChain) {
-      SymbolRefAttr useRef = symbolUse.getSymbolRef();
-      if (!isReferencePrefixOf(scope.symbol, useRef))
-        return WalkResult::advance();

-      // If we have a valid match, check to see if this is a proper
-      // subreference. If it is, then we will need to generate a different new
-      // attribute specifically for this use.
-      SymbolRefAttr replacementRef = newAttr;
-      if (useRef != scope.symbol) {
-        if (scope.symbol.isa<FlatSymbolRefAttr>()) {
-          replacementRef =
-              SymbolRefAttr::get(newSymbol, useRef.getNestedReferences());
-        } else {
-          auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
-          nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
-              newLeafAttr;
-          replacementRef =
-              SymbolRefAttr::get(useRef.getRootReference(), nestedRefs);
+    auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
+      auto remapAttrFn = [&](Attribute attr) -> Attribute {
+        if (attr == oldAttr)
+          return newAttr;
+        // Handle prefix matches.
+        if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
+          if (isReferencePrefixOf(oldAttr, symRef)) {
+            auto oldNestedRefs = oldAttr.getNestedReferences();
+            auto nestedRefs = symRef.getNestedReferences();
+            if (oldNestedRefs.empty())
+              return SymbolRefAttr::get(newSymbol, nestedRefs);
+
+            auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
+            newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
+            return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
+          }
         }
-      }
-
-      // If there was a previous operation, generate a new attribute dict
-      // for it. This means that we've finished processing the current
-      // operation, so generate a new dictionary for it.
-      if (curOp && symbolUse.getUser() != curOp) {
-        updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
-        accessChains.clear();
-      }
-
-      // Record this access.
-      curOp = symbolUse.getUser();
-      accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
+        return attr;
+      };
+      // Generate a new attribute dictionary by replacing references to the old
+      // symbol.
+      auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
+      // A null result means some sub-element could not be rebuilt. Return None
+      // rather than an interrupt: the caller only checks for an empty result,
+      // so an interrupt would be reported as success. Operations visited
+      // earlier in this walk have already been updated in place.
+      if (!newDict)
+        return llvm::None;
+
+      op->setAttrs(newDict.template cast<DictionaryAttr>());
       return WalkResult::advance();
     };
-    if (!scope.walk(walkFn))
+    if (!scope.walkSymbolTable(walkFn))
       return failure();
-
-    // Check to see if we have a dangling op that needs to be processed.
-    if (curOp) {
-      updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
-      curOp = nullptr;
-    }
   }
-
-  // Update the attribute dictionaries as necessary.
-  for (auto &it : updatedAttrDicts)
-    it.first->setAttrs(it.second);
   return success();
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 10c82b74282df..fb99d6f7c9b54 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -114,9 +114,8 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
 
 def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
     DeclareAttrInterfaceMethods<SubElementAttrInterface,
-        ["replaceImmediateSubAttribute"]>
+        ["replaceImmediateSubElements"]>
   ]> {
-
   let mnemonic = "sub_elements_access";
 
   let parameters = (ins

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2decac866afa9..810380cf9ff19 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -173,25 +173,10 @@ void TestSubElementsAccessAttr::walkImmediateSubElements(
   walkAttrsFn(getThird());
 }
 
-SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
-    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
-  Attribute first = getFirst();
-  Attribute second = getSecond();
-  Attribute third = getThird();
-  for (auto &it : replacements) {
-    switch (it.first) {
-    case 0:
-      first = it.second;
-      break;
-    case 1:
-      second = it.second;
-      break;
-    case 2:
-      third = it.second;
-      break;
-    }
-  }
-  return get(getContext(), first, second, third);
+Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
+    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
+  assert(replAttrs.size() == 3 && "invalid number of replacement attributes");
+  return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 8772efd8590fc..2ccf6ddf5f8f7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -154,6 +154,12 @@ class TestRecursiveType
       ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
     walkTypesFn(getBody());
   }
+  Type replaceImmediateSubElements(llvm::ArrayRef<mlir::Attribute> replAttrs,
+                                   llvm::ArrayRef<mlir::Type> replTypes) const {
+    // TODO: It's not clear how we support replacing sub-elements of mutable
+    // types.
+    return nullptr;
+  }
 };
 
 } // namespace test


        


More information about the Mlir-commits mailing list