[Mlir-commits] [mlir] c42dd5d - [mlir] Add new SubElementAttr/SubElementType Interfaces

River Riddle llvmlistbot at llvm.org
Thu Jun 10 17:28:47 PDT 2021


Author: River Riddle
Date: 2021-06-10T17:23:07-07:00
New Revision: c42dd5dbb015afaef99cf876195c474c63c2393e

URL: https://github.com/llvm/llvm-project/commit/c42dd5dbb015afaef99cf876195c474c63c2393e
DIFF: https://github.com/llvm/llvm-project/commit/c42dd5dbb015afaef99cf876195c474c63c2393e.diff

LOG: [mlir] Add new SubElementAttr/SubElementType Interfaces

These interfaces allow for a composite attribute or type to opaquely provide access to any held attributes or types. There are several intended use cases for this interface. The first of which is to allow the printer to create aliases for non-builtin dialect attributes and types. In the future, this interface will also be extended to allow for SymbolRefAttr to be placed on other entities aside from just DictionaryAttr and ArrayAttr.

To limit potential test breakages, this revision only adds the new interfaces to the builtin attributes/types that are currently hardcoded during AsmPrinter alias generation. In a followup the remaining builtin attributes/types, and non-builtin attributes/types can be extended to support it.

Differential Revision: https://reviews.llvm.org/D102945

Added: 
    mlir/include/mlir/IR/SubElementInterfaces.h
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/unittests/IR/SubElementInterfaceTest.cpp

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index af75c2e9401a4..28ced2deb86f0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -9,7 +9,7 @@
 #ifndef MLIR_IR_BUILTINATTRIBUTES_H
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
-#include "mlir/IR/Attributes.h"
+#include "SubElementInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include <complex>

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index cc2d3d78d3678..d514168bae367 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -15,6 +15,7 @@
 #define BUILTIN_ATTRIBUTES
 
 include "mlir/IR/BuiltinDialect.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the attributes defined in this file are prefixed with
 // `Builtin_`.  This is to differentiate the attributes here with the ones in
@@ -22,8 +23,9 @@ include "mlir/IR/BuiltinDialect.td"
 // to this file instead.
 
 // Base class for Builtin dialect attributes.
-class Builtin_Attr<string name, string baseCppClass = "::mlir::Attribute">
-    : AttrDef<Builtin_Dialect, name, [], baseCppClass> {
+class Builtin_Attr<string name, list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Attribute">
+    : AttrDef<Builtin_Dialect, name, traits, baseCppClass> {
   let mnemonic = ?;
 }
 
@@ -62,7 +64,9 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
+def Builtin_ArrayAttr : Builtin_Attr<"Array", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+  ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
     Syntax:
@@ -133,7 +137,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseIntOrFPElementsAttr
-    : Builtin_Attr<"DenseIntOrFPElements", "DenseElementsAttr"> {
+    : Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "integer or floating-point values";
   let description = [{
@@ -228,7 +232,7 @@ def Builtin_DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseStringElementsAttr
-    : Builtin_Attr<"DenseStringElements", "DenseElementsAttr"> {
+    : Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "strings";
   let description = [{
@@ -277,7 +281,9 @@ def Builtin_DenseStringElementsAttr
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
+def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+  ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
     Syntax:
@@ -589,7 +595,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_OpaqueElementsAttr
-    : Builtin_Attr<"OpaqueElements", "ElementsAttr"> {
+    : Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
     Syntax:
@@ -655,7 +661,7 @@ def Builtin_OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
 
 def Builtin_SparseElementsAttr
-    : Builtin_Attr<"SparseElements", "ElementsAttr"> {
+    : Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
     Syntax:
@@ -892,7 +898,9 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
 // TypeAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_TypeAttr : Builtin_Attr<"Type"> {
+def Builtin_TypeAttr : Builtin_Attr<"Type", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+  ]> {
   let summary = "An Attribute containing a Type";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 89745db764c2e..8b30fa94f9936 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -9,8 +9,7 @@
 #ifndef MLIR_IR_BUILTINTYPES_H
 #define MLIR_IR_BUILTINTYPES_H
 
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Types.h"
+#include "SubElementInterfaces.h"
 
 namespace llvm {
 struct fltSemantics;

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 3ad6e1512bc7e..4edf72667bd6e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -16,14 +16,16 @@
 
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
 // This is to differentiate the types here with the ones in OpBase.td. We should
 // remove the definitions in OpBase.td, and repoint users to this file instead.
 
 // Base class for Builtin dialect types.
-class Builtin_Type<string name, string baseCppClass = "::mlir::Type">
-    : TypeDef<Builtin_Dialect, name, [], baseCppClass> {
+class Builtin_Type<string name, list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Type">
+    : TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
   let mnemonic = ?;
 }
 
@@ -66,7 +68,8 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
 //===----------------------------------------------------------------------===//
 
 // Base class for Builtin dialect float types.
-class Builtin_FloatType<string name> : Builtin_Type<name, "::mlir::FloatType"> {
+class Builtin_FloatType<string name>
+    : Builtin_Type<name, /*traits=*/[], "::mlir::FloatType"> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
@@ -118,7 +121,9 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function"> {
+def Builtin_Function : Builtin_Type<"Function", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ]> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -253,7 +258,9 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
 // MemRefType
 //===----------------------------------------------------------------------===//
 
-def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
+def Builtin_MemRef : Builtin_Type<"MemRef", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -638,7 +645,9 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
 // RankedTensorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
+def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ], "TensorType"> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -726,7 +735,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
 // TupleType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Tuple : Builtin_Type<"Tuple"> {
+def Builtin_Tuple : Builtin_Type<"Tuple", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ]> {
   let summary = "Fixed-sized collection of other types";
   let description = [{
     Syntax:
@@ -793,7 +804,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple"> {
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
 
-def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
+def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -853,7 +866,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
 // UnrankedTensorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> {
+def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ], "TensorType"> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -890,7 +905,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> {
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {
+def Builtin_Vector : Builtin_Type<"Vector", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+  ], "ShapedType"> {
   let summary = "Multi-dimensional SIMD vector type";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 8dafaa1ebdac9..1a6b9c942d3f6 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -31,6 +31,13 @@ mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td)
+mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRSubElementInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
 mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
new file mode 100644
index 0000000000000..f5838b0467b91
--- /dev/null
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -0,0 +1,24 @@
+//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces and utilities for querying the sub elements of
+// an attribute or type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SUBELEMENTINTERFACES_H
+#define MLIR_INTERFACES_SUBELEMENTINTERFACES_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+/// Include the definitions of the sub elemnt interfaces.
+#include "mlir/IR/SubElementAttrInterfaces.h.inc"
+#include "mlir/IR/SubElementTypeInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_SUBELEMENTINTERFACES_H

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
new file mode 100644
index 0000000000000..8a4885237865b
--- /dev/null
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -0,0 +1,100 @@
+//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to interface with
+// sub-elements, e.g. held attributes and types, of a composite attribute or
+// type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_
+#define MLIR_IR_SUBELEMENTINTERFACES_TD_
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// SubElementInterfaceBase
+//===----------------------------------------------------------------------===//
+
+class SubElementInterfaceBase<string interfaceName, string derivedValue> {
+  string cppNamespace = "::mlir";
+
+  list<InterfaceMethod> methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Walk all of the immediately nested sub-attributes and sub-types. This
+        method does not recurse into sub elements.
+      }], "void", "walkImmediateSubElements",
+      (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
+           "llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
+    >,
+  ];
+
+  code extraClassDeclaration = [{
+    /// Walk all of the held sub-attributes.
+    void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
+      walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
+    }
+
+    /// Walk all of the held sub-types.
+    void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
+      walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
+    }
+
+    /// Walk all of the held sub-attributes and sub-types.
+    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+                         llvm::function_ref<void(mlir::Type)> walkTypesFn);
+  }];
+
+  code extraTraitClassDeclaration = [{
+    /// Walk all of the held sub-attributes.
+    void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
+      walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
+    }
+
+    /// Walk all of the held sub-types.
+    void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
+      walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
+    }
+
+    /// Walk all of the held sub-attributes and sub-types.
+    void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+                         llvm::function_ref<void(mlir::Type)> walkTypesFn) {
+      }] # interfaceName # " interface(" # derivedValue # [{);
+      interface.walkSubElements(walkAttrsFn, walkTypesFn);
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SubElementAttrInterface
+//===----------------------------------------------------------------------===//
+
+def SubElementAttrInterface
+    : AttrInterface<"SubElementAttrInterface">,
+      SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> {
+  let description = [{
+    An interface used to query and manipulate sub-elements, such as sub-types
+    and sub-attributes of a composite attribute.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SubElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def SubElementTypeInterface
+    : TypeInterface<"SubElementTypeInterface">,
+      SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> {
+  let description = [{
+    An interface used to query and manipulate sub-elements, such as sub-types
+    and sub-attributes of a composite type.
+  }];
+}
+
+#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index dace47a9e8861..4356b145e52b2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SubElementInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -626,14 +627,10 @@ void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
     return;
   }
 
-  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
-    for (Attribute element : arrayAttr.getValue())
-      visit(element);
-  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
-    for (const NamedAttribute &attr : dictAttr)
-      visit(attr.second);
-  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
-    visit(typeAttr.getValue());
+  // Check for any sub elements.
+  if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
+    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
+                                        [&](Type type) { visit(type); });
   }
 }
 
@@ -645,20 +642,10 @@ void AliasInitializer::visit(Type type) {
   if (succeeded(generateAlias(type, aliasToType)))
     return;
 
-  // Visit several subtypes that contain types or attributes.
-  if (auto funcType = type.dyn_cast<FunctionType>()) {
-    // Visit input and result types for functions.
-    for (auto input : funcType.getInputs())
-      visit(input);
-    for (auto result : funcType.getResults())
-      visit(result);
-  } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
-    visit(shapedType.getElementType());
-
-    // Visit affine maps in memref type.
-    if (auto memref = type.dyn_cast<MemRefType>())
-      for (auto map : memref.getAffineMaps())
-        visit(AffineMapAttr::get(map));
+  // Check for any sub elements.
+  if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
+    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
+                                        [&](Type type) { visit(type); });
   }
 }
 

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 79d2e2f3f409d..763ab803e6042 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -42,6 +42,17 @@ void BuiltinDialect::registerAttributes() {
                 UnitAttr>();
 }
 
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
+void ArrayAttr::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Attribute attr : getValue())
+    walkAttrsFn(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -197,6 +208,13 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
   return Base::get(context, ArrayRef<NamedAttribute>());
 }
 
+void DictionaryAttr::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Attribute attr : llvm::make_second_range(getValue()))
+    walkAttrsFn(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
@@ -1370,3 +1388,13 @@ std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
   return flatSparseIndices;
 }
+
+//===----------------------------------------------------------------------===//
+// TypeAttr
+//===----------------------------------------------------------------------===//
+
+void TypeAttr::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getValue());
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 77d64080de6e2..d5fd1eadbb69f 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -199,6 +199,13 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
   return get(getContext(), newInputTypes, newResultTypes);
 }
 
+void FunctionType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
+    walkTypesFn(type);
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
@@ -419,6 +426,12 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
   return VectorType();
 }
 
+void VectorType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
@@ -459,6 +472,12 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
   return checkTensorElementType(emitError, elementType);
 }
 
+void RankedTensorType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // UnrankedTensorType
 //===----------------------------------------------------------------------===//
@@ -469,6 +488,12 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
   return checkTensorElementType(emitError, elementType);
 }
 
+void UnrankedTensorType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
@@ -612,6 +637,15 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+void MemRefType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType());
+  walkAttrsFn(getMemorySpace());
+  for (AffineMap map : getAffineMaps())
+    walkAttrsFn(AffineMapAttr::get(map));
+}
+
 //===----------------------------------------------------------------------===//
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
@@ -779,6 +813,13 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   return success();
 }
 
+void UnrankedMemRefType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType());
+  walkAttrsFn(getMemorySpace());
+}
+
 //===----------------------------------------------------------------------===//
 /// TupleType
 //===----------------------------------------------------------------------===//
@@ -802,6 +843,13 @@ void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
 /// Return the number of element types.
 size_t TupleType::size() const { return getImpl()->size(); }
 
+void TupleType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Type type : getTypes())
+    walkTypesFn(type);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index a04bc3522a722..3f5e59d13b085 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_library(MLIRIR
   PatternMatch.cpp
   Region.cpp
   RegionKindInterface.cpp
+  SubElementInterfaces.cpp
   SymbolTable.cpp
   TensorEncoding.cpp
   Types.cpp
@@ -46,6 +47,7 @@ add_mlir_library(MLIRIR
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen
+  MLIRSubElementInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
 

diff  --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
new file mode 100644
index 0000000000000..0a4875cf7d8aa
--- /dev/null
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -0,0 +1,65 @@
+//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/SubElementInterfaces.h"
+
+using namespace mlir;
+
+template <typename InterfaceT>
+static void walkSubElementsImpl(InterfaceT interface,
+                                function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) {
+  interface.walkImmediateSubElements(
+      [&](Attribute attr) {
+        // Guard against potentially null inputs. This removes the need for the
+        // derived attribute/type to do it.
+        if (!attr)
+          return;
+
+        // Walk any sub elements first.
+        if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+
+        // Walk this attribute.
+        walkAttrsFn(attr);
+      },
+      [&](Type type) {
+        // Guard against potentially null inputs. This removes the need for the
+        // derived attribute/type to do it.
+        if (!type)
+          return;
+
+        // Walk any sub elements first.
+        if (auto interface = type.dyn_cast<SubElementTypeInterface>())
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+
+        // Walk this type.
+        walkTypesFn(type);
+      });
+}
+
+void SubElementAttrInterface::walkSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) {
+  assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+}
+
+void SubElementTypeInterface::walkSubElements(
+    function_ref<void(Attribute)> walkAttrsFn,
+    function_ref<void(Type)> walkTypesFn) {
+  assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+}
+
+//===----------------------------------------------------------------------===//
+// SubElementInterface Tablegen definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/SubElementAttrInterfaces.cpp.inc"
+#include "mlir/IR/SubElementTypeInterfaces.cpp.inc"

diff  --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 81be1831ff0b7..f9cfed62714e1 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_unittest(MLIRIRTests
   MemRefTypeTest.cpp
   OperationSupportTest.cpp
   ShapedTypeTest.cpp
+  SubElementInterfaceTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE

diff  --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp
new file mode 100644
index 0000000000000..78356fd9eaf49
--- /dev/null
+++ b/mlir/unittests/IR/SubElementInterfaceTest.cpp
@@ -0,0 +1,35 @@
+//===- SubElementInterfaceTest.cpp - SubElementInterface unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SubElementInterfaces.h"
+#include "gtest/gtest.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+TEST(SubElementInterfaceTest, Nested) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  BoolAttr trueAttr = builder.getBoolAttr(true);
+  BoolAttr falseAttr = builder.getBoolAttr(false);
+  ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr});
+  DictionaryAttr dictAttr =
+      builder.getDictionaryAttr(builder.getNamedAttr("array", boolArrayAttr));
+
+  SmallVector<Attribute> subAttrs;
+  dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); });
+  EXPECT_EQ(llvm::makeArrayRef(subAttrs),
+            ArrayRef<Attribute>({trueAttr, falseAttr, boolArrayAttr}));
+}
+
+} // end namespace


        


More information about the Mlir-commits mailing list