[Mlir-commits] [mlir] e179532 - [mlir] Remove types from attributes

Jeff Niu llvmlistbot at llvm.org
Sun Jul 31 17:01:37 PDT 2022


Author: Jeff Niu
Date: 2022-07-31T20:01:31-04:00
New Revision: e1795322844ca45ecbcdca8669929a46c666127e

URL: https://github.com/llvm/llvm-project/commit/e1795322844ca45ecbcdca8669929a46c666127e
DIFF: https://github.com/llvm/llvm-project/commit/e1795322844ca45ecbcdca8669929a46c666127e.diff

LOG: [mlir] Remove types from attributes

This patch removes the `type` field from `Attribute` along with the
`Attribute::getType` accessor.

Going forward, this means that attributes in MLIR will no longer have
types as a first-class concept. This patch lays the groundwork to
incrementally remove or refactor code that relies on generic attributes
being typed. The immediate impact will be on attributes that rely on
`Attribute` containing a type, such as `IntegerAttr`,
`DenseElementsAttr`, and `ml_program::ExternAttr`, which will now need
to define a type parameter on their storage classes. This will save
memory as all other attribute kinds will no longer contain a type.

Moreover, it will no longer be possible to generically query the type of an
attribute directly. This patch provides an attribute interface
`TypedAttr` that implements only one method, `getType`, which can be
used to generically query the types of attributes that implement the
interface. This interface can be used to retain the concept of a "typed
attribute". The ODS-generated accessor for a `type` parameter
automatically implements this method.

Next steps will be to refactor the assembly formats of certain operations
that rely on `parseAttribute(type)` and `printAttributeWithoutType` to
remove special handling of type elision until `type` can be removed from
the dialect parsing hook entirely, and to incrementally remove uses of
`TypedAttr`.

Reviewed By: lattner, rriddle, jpienaar

Differential Revision: https://reviews.llvm.org/D130092

Added: 
    mlir/include/mlir/IR/BuiltinTypeInterfaces.h

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/AsmParser/DialectSymbolParser.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/BuiltinAttributeInterfaces.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/TypeUtilities.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/lib/Target/Cpp/TranslateToCpp.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/test/IR/file-metadata-resources.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index df650e5f9a7e9..caf0078dd2b4b 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -126,9 +126,9 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
     An integer attribute is a literal attribute that represents an integral
     value of the specified integer type.
   }];
-  /// Here we've defined two parameters, one is the `self` type of the attribute
-  /// (i.e. the type of the Attribute itself), and the other is the integer value
-  /// of the attribute.
+  /// Here we've defined two parameters, one is a "self" type parameter, and the
+  /// other is the integer value of the attribute. The self type parameter is
+  /// specially handled by the assembly format.
   let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
 
   /// Here we've defined a custom builder for the type, that removes the need to pass
@@ -146,6 +146,8 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
   ///
   ///    #my.int<50> : !my.int<32> // a 32-bit integer of value 50.
   ///
+  /// Note that the self type parameter is not included in the assembly format.
+  /// Its value is derived from the optional trailing type on all attributes.
   let assemblyFormat = "`<` $value `>`";
 
   /// Indicate that our attribute will add additional verification to the parameters.
@@ -271,9 +273,8 @@ MLIR includes several specialized classes for common situations:
 - `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays of
   objects which self-allocate as per the last specialization.
 
-- `AttributeSelfTypeParameter` is a special AttrParameter that corresponds to
-  the `Type` of the attribute. Only one parameter of the attribute may be of
-  this parameter type.
+- `AttributeSelfTypeParameter` is a special `AttrParameter` that represents
+  parameters derived from the optional trailing type on attributes.
 
 ### Traits
 
@@ -702,6 +703,54 @@ available through `$_ctxt`. E.g.
 DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctxt, 32)">
 ```
 
+The values of parameters that appear __before__ the default-valued parameter
+in the parameter declaration list are available as substitutions. E.g.
+
+```tablegen
+let parameters = (ins
+  "IntegerAttr":$value,
+  DefaultValuedParameter<"Type", "$value.getType()">:$type
+);
+```
+
+###### Attribute Self Type Parameter
+
+An attribute optionally has a trailing type after the assembly format of the
+attribute value itself. MLIR parses over the attribute value and optionally
+parses a colon-type before passing the `Type` into the dialect parser hook.
+
+```
+dialect-attribute  ::= `#` dialect-namespace `<` attr-data `>`
+                       (`:` type)?
+                     | `#` alias-name pretty-dialect-sym-body? (`:` type)?
+```
+
+`AttributeSelfTypeParameter` is an attribute parameter specially handled by the
+assembly format generator. Only one such parameter can be specified, and its
+value is derived from the trailing type. When its C++ type is `::mlir::Type`
+and no `typeBuilder` is given, its default value is `NoneType::get($_ctxt)`.
+
+However, for MLIR to print the trailing type, the attribute must implement
+`TypedAttrInterface`. For example,
+
+```tablegen
+// This attribute has only a self type parameter.
+def MyExternAttr : AttrDef<MyDialect, "MyExtern", [TypedAttrInterface]> {
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+  let mnemonic = "extern";
+  let assemblyFormat = "";
+}
+```
+
+This attribute can look like:
+
+```mlir
+#my_dialect.extern // none
+#my_dialect.extern : i32
+#my_dialect.extern : tensor<4xi32>
+#my_dialect.extern : !my_dialect.my_type
+```
+
 ##### Assembly Format Directives
 
 Attribute and type assembly formats have the following directives:

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 75710d60c6d45..a36ea682de312 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 
 // Base class for Arithmetic dialect ops. Ops in this dialect have no side
@@ -147,7 +148,7 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
     ```
   }];
 
-  let arguments = (ins AnyAttr:$value);
+  let arguments = (ins TypedAttrInterface:$value);
   // TODO: Disallow arith.constant to return anything other than a signless
   // integer or float like. Downstream users of Arithmetic should only be
   // working with signless integers, floats, or vectors/tensors thereof.

diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 868089361e36b..986990ed0c47f 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -32,12 +32,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   assert(operands.size() == 2 && "binary op takes two operands");
   if (!operands[0] || !operands[1])
     return {};
-  if (operands[0].getType() != operands[1].getType())
-    return {};
 
   if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
     auto lhs = operands[0].cast<AttrElementT>();
     auto rhs = operands[1].cast<AttrElementT>();
+    if (lhs.getType() != rhs.getType())
+      return {};
 
     auto calRes = calculate(lhs.getValue(), rhs.getValue());
 
@@ -53,6 +53,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // just fold based on the splat value.
     auto lhs = operands[0].cast<SplatElementsAttr>();
     auto rhs = operands[1].cast<SplatElementsAttr>();
+    if (lhs.getType() != rhs.getType())
+      return {};
 
     auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
                                    rhs.getSplatValue<ElementValueT>());
@@ -66,6 +68,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // expanding the values.
     auto lhs = operands[0].cast<ElementsAttr>();
     auto rhs = operands[1].cast<ElementsAttr>();
+    if (lhs.getType() != rhs.getType())
+      return {};
 
     auto lhsIt = lhs.value_begin<ElementValueT>();
     auto rhsIt = rhs.value_begin<ElementValueT>();

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
index afae94c41ad8b..886ae3a3f721f 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
@@ -10,18 +10,21 @@
 #define COMPLEX_ATTRIBUTE
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/Dialect/Complex/IR/ComplexBase.td"
 
 //===----------------------------------------------------------------------===//
 // Complex Attributes.
 //===----------------------------------------------------------------------===//
 
-class Complex_Attr<string attrName, string attrMnemonic>
-    : AttrDef<Complex_Dialect, attrName> {
+class Complex_Attr<string attrName, string attrMnemonic,
+                   list<Trait> traits = []>
+    : AttrDef<Complex_Dialect, attrName, traits> {
   let mnemonic = attrMnemonic;
 }
 
-def Complex_NumberAttr : Complex_Attr<"Number", "number"> {
+def Complex_NumberAttr : Complex_Attr<"Number", "number",
+                                      [TypedAttrInterface]> {
   let summary = "A complex number attribute";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6ddf10e7935dd..ac86e92c7c348 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -139,7 +139,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
     ```
   }];
 
-  let arguments = (ins AnyAttr:$value);
+  let arguments = (ins TypedAttrInterface:$value);
   let results = (outs AnyType);
 
   let hasFolder = 1;
@@ -212,7 +212,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
     ```
   }];
 
-  let arguments = (ins AnyAttr:$value);
+  let arguments = (ins TypedAttrInterface:$value);
   let results = (outs AnyType);
 
   let hasVerifier = 1;

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
index d9f724ae5b06a..22b2f55872d90 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
@@ -14,18 +14,19 @@
 #define MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/Dialect/EmitC/IR/EmitCBase.td"
 
 //===----------------------------------------------------------------------===//
 // EmitC attribute definitions
 //===----------------------------------------------------------------------===//
 
-class EmitC_Attr<string name, string attrMnemonic>
-    : AttrDef<EmitC_Dialect, name> {
+class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<EmitC_Dialect, name, traits> {
   let mnemonic = attrMnemonic;
 }
 
-def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
+def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque", [TypedAttrInterface]> {
   let summary = "An opaque attribute";
 
   let description = [{
@@ -40,8 +41,9 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
     ```
   }];
 
-  let parameters = (ins StringRefParameter<"the opaque value">:$value);
-  
+  let parameters = (ins "Type":$type,
+                        StringRefParameter<"the opaque value">:$value);
+
   let hasCustomAssemblyFormat = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
index 253daedcc605a..3836684726cb3 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // Tablegen Attribute Declarations

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
index 1323afec2e9a3..eb6e293bbf4f6 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
@@ -10,6 +10,7 @@
 #define MLPROGRAM_ATTRIBUTES
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
 
 // Base class for MLProgram dialect attributes.
@@ -22,7 +23,7 @@ class MLProgram_Attr<string name, list<Trait> traits = []>
 // ExternAttr
 //===----------------------------------------------------------------------===//
 
-def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> {
+def MLProgram_ExternAttr : MLProgram_Attr<"Extern", [TypedAttrInterface]> {
   let summary = "Value used for a global signalling external resolution";
   let description = [{
   When used as the value for a GlobalOp, this indicates that the actual

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 7f7c686a24632..d52378c6f0156 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS
 
 include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -600,7 +601,7 @@ def SPV_SpecConstantOp : SPV_Op<"SpecConstant", [InModuleScope, Symbol]> {
 
   let arguments = (ins
     StrAttr:$sym_name,
-    AnyAttr:$default_value
+    TypedAttrInterface:$default_value
   );
 
   let results = (outs);

diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 93dd128d329b0..5bf1dfe3239b0 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -257,14 +257,6 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   let convertFromStorage = "$_self.cast<" # dialect.cppNamespace #
                                  "::" # cppClassName # ">()";
 
-  // A code block used to build the value 'Type' of an Attribute when
-  // initializing its storage instance. This field is optional, and if not
-  // present the attribute will have its value type set to `NoneType`. This code
-  // block may reference any of the attributes parameters via
-  // `$_<parameter-name`. If one of the parameters of the attribute is of type
-  // `AttributeSelfTypeParameter`, this field is ignored.
-  code typeBuilder = ?;
-
   // The predicate for when this def is used as a constraint.
   let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
                                  "::" # cppClassName # ">()">;
@@ -334,7 +326,7 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // which by default is the C++ equality operator. The current MLIR context is
   // made available through `$_ctxt`, e.g., for constructing default values for
   // attributes and types.
-  string defaultValue = ?;
+  string defaultValue = "";
 }
 class AttrParameter<string type, string desc, string accessorType = "">
  : AttrOrTypeParameter<type, desc, accessorType>;
@@ -392,11 +384,21 @@ class DefaultValuedParameter<string type, string value, string desc = ""> :
   let defaultValue = value;
 }
 
-// This is a special parameter used for AttrDefs that represents a `mlir::Type`
-// that is also used as the value `Type` of the attribute. Only one parameter
-// of the attribute may be of this type.
+// This is a special attribute parameter that represents the "self" type of the
+// attribute. It is specially handled by the assembly format generator to derive
+// its value from the optional trailing type after each attribute.
+//
+// By default, the self type parameter is optional and has a default value of
+// `none`. If a derived type other than `::mlir::Type` is specified, the
+// parameter loses its default value unless another one is specified by
+// `typeBuilder`.
 class AttributeSelfTypeParameter<string desc,
-                                 string derivedType = "::mlir::Type"> :
-    AttrOrTypeParameter<derivedType, desc> {}
+                                 string derivedType = "::mlir::Type",
+                                 string typeBuilder = ""> :
+    AttrOrTypeParameter<derivedType, desc> {
+  let defaultValue = !if(!and(!empty(typeBuilder),
+                              !eq(derivedType, "::mlir::Type")),
+                         "::mlir::NoneType::get($_ctxt)", typeBuilder);
+}
 
 #endif // ATTRTYPEBASE_TD

diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 948eaf8327dc3..58e37f09e1c6c 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -129,9 +129,6 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
   friend StorageUniquer;
 
 public:
-  /// Get the type of this attribute.
-  Type getType() const { return type; }
-
   /// Return the abstract descriptor for this attribute.
   const AbstractAttribute &getAbstractAttribute() const {
     assert(abstractAttribute && "Malformed attribute storage object.");
@@ -139,15 +136,6 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
   }
 
 protected:
-  /// Construct a new attribute storage instance with the given type.
-  /// Note: All attributes require a valid type. If no type is provided here,
-  ///       the type of the attribute will automatically default to NoneType
-  ///       upon initialization in the uniquer.
-  AttributeStorage(Type type = nullptr) : type(type) {}
-
-  /// Set the type of this attribute.
-  void setType(Type newType) { type = newType; }
-
   /// Set the abstract attribute for this storage instance. This is used by the
   /// AttributeUniquer when initializing a newly constructed storage object.
   void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
@@ -159,9 +147,6 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
   void initialize(MLIRContext *context) {}
 
 private:
-  /// The type of the attribute value.
-  Type type;
-
   /// The abstract descriptor for this attribute.
   const AbstractAttribute *abstractAttribute = nullptr;
 };

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 1c38db9933d41..6ebb0449da336 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -66,9 +66,6 @@ class Attribute {
   /// to support dynamic type casting.
   TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
 
-  /// Return the type of this attribute.
-  Type getType() const { return impl->getType(); }
-
   /// Return the context this attribute belongs to.
   MLIRContext *getContext() const;
 

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index e4da88f300b44..bd0d468670497 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Any.h"
@@ -18,7 +19,6 @@
 #include <complex>
 
 namespace mlir {
-class ShapedType;
 
 //===----------------------------------------------------------------------===//
 // ElementsAttr
@@ -237,10 +237,10 @@ class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
 public:
   using reference = typename IteratorT::reference;
 
-  ElementsAttrRange(Type shapeType,
+  ElementsAttrRange(ShapedType shapeType,
                     const llvm::iterator_range<IteratorT> &range)
       : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
-  ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt)
+  ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
       : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
 
   /// Return the value at the given index.
@@ -254,7 +254,7 @@ class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
 
 private:
   /// The shaped type of the parent ElementsAttr.
-  Type shapeType;
+  ShapedType shapeType;
 };
 
 } // namespace detail

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 45295e874f3bd..58dad618d3cee 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -154,7 +154,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{
         // By default, only check for a single element splat.
         return $_attr.getNumElements() == 1;
-    }]>
+    }]>,
+    InterfaceMethod<[{
+      Returns the shaped type of the elements attribute.
+    }], "::mlir::ShapedType", "getType">
   ];
 
   string ElementsAttrInterfaceAccessors = [{
@@ -280,7 +283,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     auto getValues() const {
       auto beginIt = $_attr.template value_begin<T>();
       return detail::ElementsAttrRange<decltype(beginIt)>(
-        Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
+        $_attr.getType(), beginIt, std::next(beginIt, size()));
     }
   }] # ElementsAttrInterfaceAccessors;
 
@@ -294,19 +297,17 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     // Accessors
     //===------------------------------------------------------------------===//
 
-    /// Return the type of this attribute.
-    ShapedType getType() const;
-
     /// Return the element type of this ElementsAttr.
     Type getElementType() const { return getElementType(*this); }
-    static Type getElementType(Attribute elementsAttr);
+    static Type getElementType(ElementsAttr elementsAttr);
 
     /// Return if the given 'index' refers to a valid element in this attribute.
     bool isValidIndex(ArrayRef<uint64_t> index) const {
       return isValidIndex(*this, index);
     }
     static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
-    static bool isValidIndex(Attribute elementsAttr, ArrayRef<uint64_t> index);
+    static bool isValidIndex(ElementsAttr elementsAttr,
+                             ArrayRef<uint64_t> index);
 
     /// Return the 1 dimensional flattened row-major index from the given
     /// multi-dimensional index.
@@ -315,14 +316,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     }
     static uint64_t getFlattenedIndex(Type type,
                                       ArrayRef<uint64_t> index);
-    static uint64_t getFlattenedIndex(Attribute elementsAttr,
+    static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
                                       ArrayRef<uint64_t> index) {
       return getFlattenedIndex(elementsAttr.getType(), index);
     }
 
     /// Returns the number of elements held by this attribute.
     int64_t getNumElements() const { return getNumElements(*this); }
-    static int64_t getNumElements(Attribute elementsAttr);
+    static int64_t getNumElements(ElementsAttr elementsAttr);
 
     //===------------------------------------------------------------------===//
     // Value Iteration
@@ -349,7 +350,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// Return the elements of this attribute as a value of type 'T'.
     template <typename T>
     DefaultValueCheckT<T, iterator_range<T>> getValues() const {
-      return {Attribute::getType(), value_begin<T>(), value_end<T>()};
+      return {getType(), value_begin<T>(), value_end<T>()};
     }
     template <typename T>
     DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -369,8 +370,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     DerivedAttrValueIteratorRange<T> getValues() const {
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-      return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
-                             static_cast<T (*)(Attribute)>(castFn))};
+      return {getType(), llvm::map_range(getValues<Attribute>(),
+              static_cast<T (*)(Attribute)>(castFn))};
     }
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     DerivedAttrValueIterator<T> value_begin() const {
@@ -388,10 +389,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// return the iterable range. Otherwise, return llvm::None.
     template <typename T>
     DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
-      if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
-        return iterator_range<T>(Attribute::getType(), *beginIt,
-                                 value_end<T>());
-      }
+      if (Optional<iterator<T>> beginIt = try_value_begin<T>())
+        return iterator_range<T>(getType(), *beginIt, value_end<T>());
       return llvm::None;
     }
     template <typename T>
@@ -407,7 +406,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
 
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
       return DerivedAttrValueIteratorRange<T>(
-        Attribute::getType(),
+        getType(),
         llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
       );
     }
@@ -468,4 +467,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// TypedAttrInterface
+//===----------------------------------------------------------------------===//
+
+def TypedAttrInterface : AttrInterface<"TypedAttr"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    This interface is used for attributes that have a type. The type of an
+    attribute is understood to represent the type of the data contained in the
+    attribute and is often used as the type of a value with this data.
+  }];
+
+  let methods = [InterfaceMethod<
+    "Get the attribute's type",
+    "::mlir::Type", "getType"
+  >];
+}
+
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 29f377b214e6e..7adec3305a48c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -25,7 +25,6 @@ class IntegerSet;
 class IntegerType;
 class Location;
 class Operation;
-class ShapedType;
 
 //===----------------------------------------------------------------------===//
 // Elements Attributes
@@ -402,7 +401,7 @@ class DenseElementsAttr : public Attribute {
                              std::numeric_limits<T>::is_signed));
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
+    return {getType(), ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
   template <typename T, typename = IntFloatValueTemplateCheckT<T>>
@@ -431,7 +430,7 @@ class DenseElementsAttr : public Attribute {
                           std::numeric_limits<ElementT>::is_signed));
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
+    return {getType(), ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
   template <typename T, typename ElementT = typename T::value_type,
@@ -458,7 +457,7 @@ class DenseElementsAttr : public Attribute {
     auto stringRefs = getRawStringData();
     const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
     bool splat = isSplat();
-    return {Attribute::getType(), ElementIterator<StringRef>(ptr, splat, 0),
+    return {getType(), ElementIterator<StringRef>(ptr, splat, 0),
             ElementIterator<StringRef>(ptr, splat, getNumElements())};
   }
   template <typename T, typename = StringRefValueTemplateCheckT<T>>
@@ -478,8 +477,7 @@ class DenseElementsAttr : public Attribute {
       typename std::enable_if<std::is_same<T, Attribute>::value>::type;
   template <typename T, typename = AttributeValueTemplateCheckT<T>>
   iterator_range_impl<AttributeElementIterator> getValues() const {
-    return {Attribute::getType(), value_begin<Attribute>(),
-            value_end<Attribute>()};
+    return {getType(), value_begin<Attribute>(), value_end<Attribute>()};
   }
   template <typename T, typename = AttributeValueTemplateCheckT<T>>
   AttributeElementIterator value_begin() const {
@@ -510,7 +508,7 @@ class DenseElementsAttr : public Attribute {
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
     using DerivedIterT = DerivedAttributeElementIterator<T>;
-    return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
+    return {getType(), DerivedIterT(value_begin<Attribute>()),
             DerivedIterT(value_end<Attribute>())};
   }
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
@@ -530,7 +528,7 @@ class DenseElementsAttr : public Attribute {
   template <typename T, typename = BoolValueTemplateCheckT<T>>
   iterator_range_impl<BoolElementIterator> getValues() const {
     assert(isValidBool() && "bool is not the value of this elements attribute");
-    return {Attribute::getType(), BoolElementIterator(*this, 0),
+    return {getType(), BoolElementIterator(*this, 0),
             BoolElementIterator(*this, getNumElements())};
   }
   template <typename T, typename = BoolValueTemplateCheckT<T>>
@@ -552,7 +550,7 @@ class DenseElementsAttr : public Attribute {
   template <typename T, typename = APIntValueTemplateCheckT<T>>
   iterator_range_impl<IntElementIterator> getValues() const {
     assert(getElementType().isIntOrIndex() && "expected integral type");
-    return {Attribute::getType(), raw_int_begin(), raw_int_end()};
+    return {getType(), raw_int_begin(), raw_int_end()};
   }
   template <typename T, typename = APIntValueTemplateCheckT<T>>
   IntElementIterator value_begin() const {
@@ -991,8 +989,6 @@ inline bool operator==(StringRef lhs, StringAttr rhs) {
 }
 inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }
 
-inline Type StringAttr::getType() const { return Attribute::getType(); }
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 67fa0a31b5670..0b620908c2069 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -64,7 +64,6 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
     AffineMap getAffineMap() const { return getValue(); }
   }];
   let skipDefaultBuilders = 1;
-  let typeBuilder = "IndexType::get($_value.getContext())";
 }
 
 //===----------------------------------------------------------------------===//
@@ -140,11 +139,11 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 }
 
 //===----------------------------------------------------------------------===//
-// DenseIntOrFPElementsAttr
+// DenseArrayBaseAttr
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseArrayBase : Builtin_Attr<
-    "DenseArrayBase", [ElementsAttrInterface]> {
+    "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
   let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
   let description = [{
     A dense array attribute is an attribute that represents a dense array of
@@ -197,8 +196,12 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     const float *value_begin_impl(OverloadToken<float>) const;
     const double *value_begin_impl(OverloadToken<double>) const;
 
-    /// Methods to support type inquiry through isa, cast, and dyn_cast.
+    /// Returns the shaped type, containing the number of elements in the array
+    /// and the array element type.
+    ShapedType getType() const;
+    /// Returns the element type.
     EltType getElementType() const;
+
     /// Printer for the short form: will dispatch to the appropriate subclass.
     void print(AsmPrinter &printer) const;
     void print(raw_ostream &os) const;
@@ -216,7 +219,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
-    "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
+    "DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface],
+    "DenseElementsAttr"
   > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "integer or floating-point values";
@@ -355,7 +359,8 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseStringElementsAttr : Builtin_Attr<
-    "DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr"
+    "DenseStringElements", [ElementsAttrInterface, TypedAttrInterface],
+    "DenseElementsAttr"
   > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "strings";
@@ -523,7 +528,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
 // FloatAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_FloatAttr : Builtin_Attr<"Float"> {
+def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
   let summary = "An Attribute containing a floating-point value";
   let description = [{
     Syntax:
@@ -586,7 +591,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float"> {
 // IntegerAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_IntegerAttr : Builtin_Attr<"Integer"> {
+def Builtin_IntegerAttr : Builtin_Attr<"Integer", [TypedAttrInterface]> {
   let summary = "An Attribute containing a integer value";
   let description = [{
     Syntax:
@@ -703,7 +708,7 @@ def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet"> {
 // OpaqueAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
+def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> {
   let summary = "An opaque representation of another Attribute";
   let description = [{
     Syntax:
@@ -741,7 +746,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_OpaqueElementsAttr : Builtin_Attr<
-    "OpaqueElements", [ElementsAttrInterface]
+    "OpaqueElements", [ElementsAttrInterface, TypedAttrInterface]
   > {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
@@ -803,7 +808,7 @@ def Builtin_OpaqueElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_SparseElementsAttr : Builtin_Attr<
-    "SparseElements", [ElementsAttrInterface]
+    "SparseElements", [ElementsAttrInterface, TypedAttrInterface]
   > {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
@@ -966,7 +971,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
 // StringAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_StringAttr : Builtin_Attr<"String"> {
+def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> {
   let summary = "An Attribute containing a string";
   let description = [{
     Syntax:

diff  --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
new file mode 100644
index 0000000000000..6468ca53c35c7
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -0,0 +1,14 @@
+//===- BuiltinTypeInterfaces.h - Builtin Type Interfaces --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BUILTINTYPEINTERFACES_H
+#define MLIR_IR_BUILTINTYPEINTERFACES_H
+
+#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
+
+#endif // MLIR_IR_BUILTINTYPEINTERFACES_H

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 4bdc10a25023f..76b3ef8fa2d93 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -9,8 +9,9 @@
 #ifndef MLIR_IR_BUILTINTYPES_H
 #define MLIR_IR_BUILTINTYPES_H
 
-#include "BuiltinAttributeInterfaces.h"
-#include "SubElementInterfaces.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/SubElementInterfaces.h"
 
 namespace llvm {
 class BitVector;
@@ -21,8 +22,6 @@ struct fltSemantics;
 // Tablegen Interface Declarations
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
-
 namespace mlir {
 class AffineExpr;
 class AffineMap;

diff  --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 2e488332c9596..f38ce43748a19 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -215,8 +215,9 @@ static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
 /// Parse an extended attribute.
 ///
 ///   extended-attribute ::= (dialect-attribute | attribute-alias)
-///   dialect-attribute  ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
-///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
+///   dialect-attribute  ::= `#` dialect-namespace `<` attr-data `>`
+///                          (`:` type)?
+///                        | `#` alias-name pretty-dialect-sym-body? (`:` type)?
 ///   attribute-alias    ::= `#` alias-name
 ///
 Attribute Parser::parseExtendedAttr(Type type) {
@@ -250,9 +251,10 @@ Attribute Parser::parseExtendedAttr(Type type) {
       });
 
   // Ensure that the attribute has the same type as requested.
-  if (attr && type && attr.getType() != type) {
+  auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
+  if (type && typedAttr && typedAttr.getType() != type) {
     emitError("attribute type different than expected: expected ")
-        << type << ", but got " << attr.getType();
+        << type << ", but got " << typedAttr.getType();
     return nullptr;
   }
   return attr;

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index da43da1af57f4..435c974e7c6e8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -753,7 +753,10 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
 }
 
 MlirType mlirAttributeGetType(MlirAttribute attribute) {
-  return wrap(unwrap(attribute).getType());
+  Attribute attr = unwrap(attribute);
+  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+    return wrap(typedAttr.getType());
+  return wrap(NoneType::get(attr.getContext()));
 }
 
 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index a6ef77c5cc560..6293e7448a64e 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -395,8 +395,8 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
     return failure();
 
   Attribute cstAttr = constOp.getValue();
-  if (cstAttr.getType().isa<ShapedType>())
-    cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
+  if (auto elementsAttr = cstAttr.dyn_cast<DenseElementsAttr>())
+    cstAttr = elementsAttr.getSplatValue<Attribute>();
 
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index ed41db5be32a9..6691193146f2f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -698,8 +698,8 @@ static void convertConstantOp(arith::ConstantOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   assert(constantSupportsMMAMatrixType(op));
   OpBuilder b(op);
-  Attribute splat =
-      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
+  auto splat =
+      op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
   auto scalarConstant =
       b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
   const char *fragType = inferFragType(op);

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 4d0d50e4d6350..537dc6cbe96ed 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -128,7 +128,8 @@ LogicalResult arith::ConstantOp::verify() {
 
 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
   // The value's type must be the same as the provided type.
-  if (value.getType() != type)
+  auto typedAttr = value.dyn_cast<TypedAttr>();
+  if (!typedAttr || typedAttr.getType() != type)
     return false;
   // Integer values must be signless.
   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 445d391e65b7a..cfc2e00abdec9 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -30,11 +30,13 @@ void ConstantOp::getAsmResultNames(
 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
     auto complexTy = type.dyn_cast<ComplexType>();
-    if (!complexTy)
+    if (!complexTy || arrAttr.size() != 2)
       return false;
     auto complexEltTy = complexTy.getElementType();
-    return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
-           arrAttr[1].getType() == complexEltTy;
+    auto re = arrAttr[0].dyn_cast<FloatAttr>();
+    auto im = arrAttr[1].dyn_cast<FloatAttr>();
+    return re && im && re.getType() == complexEltTy &&
+           im.getType() == complexEltTy;
   }
   return false;
 }
@@ -48,11 +50,14 @@ LogicalResult ConstantOp::verify() {
   }
 
   auto complexEltTy = getType().getElementType();
-  if (complexEltTy != arrayAttr[0].getType() ||
-      complexEltTy != arrayAttr[1].getType()) {
+  auto re = arrayAttr[0].dyn_cast<FloatAttr>();
+  auto im = arrayAttr[1].dyn_cast<FloatAttr>();
+  if (!re || !im)
+    return emitOpError("requires attribute's elements to be float attributes");
+  if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
     return emitOpError()
-           << "requires attribute's element types (" << arrayAttr[0].getType()
-           << ", " << arrayAttr[1].getType()
+           << "requires attribute's element types (" << re.getType() << ", "
+           << im.getType()
            << ") to match the element type of the op's return type ("
            << complexEltTy << ")";
   }

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index edfcb369399e3..7f5e15a9c76e8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -86,15 +86,17 @@ LogicalResult emitc::CallOp::verify() {
 
   if (Optional<ArrayAttr> argsAttr = getArgs()) {
     for (Attribute arg : *argsAttr) {
-      if (arg.getType().isa<IndexType>()) {
-        int64_t index = arg.cast<IntegerAttr>().getInt();
+      auto intAttr = arg.dyn_cast<IntegerAttr>();
+      if (intAttr && intAttr.getType().isa<IndexType>()) {
+        int64_t index = intAttr.getInt();
         // Args with elements of type index must be in range
         // [0..operands.size).
         if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
           return emitOpError("index argument is out of range");
 
         // Args with elements of type ArrayAttr must have a type.
-      } else if (arg.isa<ArrayAttr>() && arg.getType().isa<NoneType>()) {
+      } else if (arg.isa<ArrayAttr>() /*&& arg.getType().isa<NoneType>()*/) {
+        // FIXME: Array attributes never have types
         return emitOpError("array argument has no type");
       }
     }
@@ -102,8 +104,7 @@ LogicalResult emitc::CallOp::verify() {
 
   if (Optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
     for (Attribute tArg : *templateArgsAttr) {
-      if (!tArg.isa<TypeAttr>() && !tArg.isa<IntegerAttr>() &&
-          !tArg.isa<FloatAttr>() && !tArg.isa<emitc::OpaqueAttr>())
+      if (!tArg.isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>())
         return emitOpError("template argument has invalid type");
     }
   }
@@ -117,7 +118,7 @@ LogicalResult emitc::CallOp::verify() {
 
 /// The constant op requires that the attribute's type matches the return type.
 LogicalResult emitc::ConstantOp::verify() {
-  Attribute value = getValueAttr();
+  TypedAttr value = getValueAttr();
   Type type = getType();
   if (!value.getType().isa<NoneType>() && type != value.getType())
     return emitOpError() << "requires attribute's type (" << value.getType()
@@ -171,7 +172,7 @@ ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
 
 /// The variable op requires that the attribute's type matches the return type.
 LogicalResult emitc::VariableOp::verify() {
-  Attribute value = getValueAttr();
+  TypedAttr value = getValueAttr();
   Type type = getType();
   if (!value.getType().isa<NoneType>() && type != value.getType())
     return emitOpError() << "requires attribute's type (" << value.getType()
@@ -204,7 +205,9 @@ Attribute emitc::OpaqueAttr::parse(AsmParser &parser, Type type) {
   }
   if (parser.parseGreater())
     return Attribute();
-  return get(parser.getContext(), value);
+
+  return get(parser.getContext(),
+             type ? type : NoneType::get(parser.getContext()), value);
 }
 
 void emitc::OpaqueAttr::print(AsmPrinter &printer) const {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e77f5a118c0e3..33fe8902b977b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2409,11 +2409,16 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
 
     auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
-    if (!arrayAttr || arrayAttr.size() != 2 ||
-        arrayAttr[0].getType() != arrayAttr[1].getType()) {
+    if (!arrayAttr || arrayAttr.size() != 2) {
       return emitOpError() << "expected array attribute with two elements, "
                               "representing a complex constant";
     }
+    auto re = arrayAttr[0].dyn_cast<TypedAttr>();
+    auto im = arrayAttr[1].dyn_cast<TypedAttr>();
+    if (!re || !im || re.getType() != im.getType()) {
+      return emitOpError()
+             << "expected array attribute with two elements of the same type";
+    }
 
     Type elementType = structType.getBody()[0];
     if (!elementType

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7c0c30f068e..a55b29cad363d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -400,8 +400,10 @@ class RegionBuilderHelper {
     OpBuilder builder = getBuilder();
     Location loc = builder.getUnknownLoc();
     Attribute valueAttr = parseAttribute(value, builder.getContext());
-    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
-                                             valueAttr);
+    Type type = NoneType::get(builder.getContext());
+    if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
+      type = typedAttr.getType();
+    return builder.create<arith::ConstantOp>(loc, type, valueAttr);
   }
 
   Value index(int64_t dim) {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8b9187bb6a3d..246a23c79bd91 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -530,7 +530,11 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
   SmallVector<Attribute> paddingValues;
   for (auto const &it :
        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
-    Attribute attr = std::get<0>(it);
+    auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
+    if (!attr) {
+      emitOpError("expects padding values to be typed attributes");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
     Type elementType = getElementTypeOrSelf(std::get<1>(it));
     // Try to parse string attributes to obtain an attribute of element type.
     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a6d5dd49b33ac..34305a0d3887f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1509,14 +1509,14 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       return failure();
     for (OpOperand *opOperand : genericOp.getInputOperands()) {
       Operation *def = opOperand->get().getDefiningOp();
-      Attribute constantAttr;
+      TypedAttr constantAttr;
       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
         {
           DenseElementsAttr splatAttr;
           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
               splatAttr.isSplat() &&
               splatAttr.getType().getElementType().isIntOrFloat()) {
-            constantAttr = splatAttr.getSplatValue<Attribute>();
+            constantAttr = splatAttr.getSplatValue<TypedAttr>();
             return true;
           }
         }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e7dfa725aa67c..600300855be5b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -198,8 +198,11 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
   if (opOperand->getOperandNumber() >= paddingValues.size())
     return failure();
   Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
-  Value paddingValue = b.create<arith::ConstantOp>(
-      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);
+  Type paddingType = b.getType<NoneType>();
+  if (auto typedAttr = paddingAttr.dyn_cast<TypedAttr>())
+    paddingType = typedAttr.getType();
+  Value paddingValue =
+      b.create<arith::ConstantOp>(opToPad.getLoc(), paddingType, paddingAttr);
 
   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
   OpOperand *currOpOperand = opOperand;

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2e35cda377191..616b228051e5e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1309,8 +1309,8 @@ LogicalResult GlobalOp::verify() {
 
     // Check that the type of the initial value is compatible with the type of
     // the global variable.
-    if (initValue.isa<ElementsAttr>()) {
-      Type initType = initValue.getType();
+    if (auto elementsAttr = initValue.dyn_cast<ElementsAttr>()) {
+      Type initType = elementsAttr.getType();
       Type tensorType = getTensorTypeFromMemRefType(memrefType);
       if (initType != tensorType)
         return emitOpError("initial value expected to be of type ")

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9ca9c15e119fc..4396c96766880 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -28,20 +28,15 @@ using namespace mlir;
 
 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
 /// or splat vector bool constant.
-static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
-  if (!boolAttr)
+static Optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
+  if (!attr)
     return llvm::None;
 
-  auto type = boolAttr.getType();
-  if (type.isInteger(1)) {
-    auto attr = boolAttr.cast<BoolAttr>();
-    return attr.getValue();
-  }
-  if (auto vecType = type.cast<VectorType>()) {
-    if (vecType.getElementType().isInteger(1))
-      if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
-        return attr.getSplatValue<bool>();
-  }
+  if (auto boolAttr = attr.dyn_cast<BoolAttr>())
+    return boolAttr.getValue();
+  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
+    if (splatAttr.getElementType().isInteger(1))
+      return splatAttr.getSplatValue<bool>();
   return llvm::None;
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index e79d1d2c220f6..ac78e08d56097 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1803,7 +1803,9 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
   if (parser.parseAttribute(value, kValueAttrName, state.attributes))
     return failure();
 
-  Type type = value.getType();
+  Type type = NoneType::get(parser.getContext());
+  if (auto typedAttr = value.dyn_cast<TypedAttr>())
+    type = typedAttr.getType();
   if (type.isa<NoneType, TensorType>()) {
     if (parser.parseColonType(type))
       return failure();
@@ -1820,15 +1822,15 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) {
 
 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                                         Type opType) {
-  auto valueType = value.getType();
-
   if (value.isa<IntegerAttr, FloatAttr>()) {
+    auto valueType = value.cast<TypedAttr>().getType();
     if (valueType != opType)
       return op.emitOpError("result type (")
              << opType << ") does not match value type (" << valueType << ")";
     return success();
   }
   if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
+    auto valueType = value.cast<TypedAttr>().getType();
     if (valueType == opType)
       return success();
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
@@ -1873,7 +1875,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
     }
     return success();
   }
-  return op.emitOpError("cannot have value of type ") << valueType;
+  return op.emitOpError("cannot have attribute: ") << value;
 }
 
 LogicalResult spirv::ConstantOp::verify() {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 764ecc4eb53cb..c78cc1a459ab9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1737,7 +1737,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0])
     return {};
   auto vectorType = getVectorType();
-  if (operands[0].getType().isIntOrIndexOrFloat())
+  if (operands[0].isa<IntegerAttr, FloatAttr>())
     return DenseElementsAttr::get(vectorType, operands[0]);
   if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
@@ -1855,7 +1855,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
   if (!lhs || !rhs)
     return {};
 
-  auto lhsType = lhs.getType().cast<VectorType>();
+  auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
   // manipulation.
   if (lhsType.getRank() != 1)

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2dc324f455bed..e5fd5eaa2d0a7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1752,7 +1752,6 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
   if (succeeded(printAlias(attr)))
     return;
 
-  auto attrType = attr.getType();
   if (!isa<BuiltinDialect>(attr.getDialect())) {
     printDialectAttribute(attr);
   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
@@ -1768,7 +1767,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     os << '}';
 
   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
-    if (attrType.isSignlessInteger(1)) {
+    Type intType = intAttr.getType();
+    if (intType.isSignlessInteger(1)) {
       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
 
       // Boolean integer attributes always elides the type.
@@ -1779,18 +1779,18 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
     // values print as signed.
     bool isUnsigned =
-        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
+        intType.isUnsignedInteger() || intType.isSignlessInteger(1);
     intAttr.getValue().print(os, !isUnsigned);
 
     // IntegerAttr elides the type if I64.
-    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
+    if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
       return;
 
   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
     printFloatValue(floatAttr.getValue(), os);
 
     // FloatAttr elides the type if F64.
-    if (typeElision == AttrTypeElision::May && attrType.isF64())
+    if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
       return;
 
   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
@@ -1892,7 +1892,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
       os << "[:f64";
       break;
     }
-    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
+    if (denseArrayAttr.getType().getRank())
       os << " ";
     denseArrayAttr.printWithoutBraces(os);
     os << "]";
@@ -1902,9 +1902,14 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     llvm::report_fatal_error("Unknown builtin attribute");
   }
   // Don't print the type if we must elide it, or if it is a None type.
-  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
-    os << " : ";
-    printType(attrType);
+  if (typeElision != AttrTypeElision::Must) {
+    if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
+      Type attrType = typedAttr.getType();
+      if (!attrType.isa<NoneType>()) {
+        os << " : ";
+        printType(attrType);
+      }
+    }
   }
 }
 

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 747d7275d75c8..ced9dcf6c7b7c 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -43,9 +43,10 @@ inline size_t getDenseElementBitWidth(Type eltType) {
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
 public:
-  DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
-      : AttributeStorage(ty), isSplat(isSplat) {}
+  DenseElementsAttributeStorage(ShapedType type, bool isSplat)
+      : type(type), isSplat(isSplat) {}
 
+  ShapedType type;
   bool isSplat;
 };
 
@@ -75,7 +76,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
 
   /// Compare this storage instance with the provided key.
   bool operator==(const KeyTy &key) const {
-    if (key.type != getType())
+    if (key.type != type)
       return false;
 
     // For boolean splats we need to explicitly check that the first bit is the
@@ -228,7 +229,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
 
   /// Compare this storage instance with the provided key.
   bool operator==(const KeyTy &key) const {
-    if (key.type != getType())
+    if (key.type != type)
       return false;
 
     // Otherwise, we can default to just checking the data. StringRefs compare
@@ -324,12 +325,12 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
 
 struct StringAttrStorage : public AttributeStorage {
   StringAttrStorage(StringRef value, Type type)
-      : AttributeStorage(type), value(value), referencedDialect(nullptr) {}
+      : type(type), value(value), referencedDialect(nullptr) {}
 
   /// The hash key is a tuple of the parameter types.
   using KeyTy = std::pair<StringRef, Type>;
   bool operator==(const KeyTy &key) const {
-    return value == key.first && getType() == key.second;
+    return value == key.first && type == key.second;
   }
   static ::llvm::hash_code hashKey(const KeyTy &key) {
     return DenseMapInfo<KeyTy>::getHashValue(key);
@@ -346,6 +347,8 @@ struct StringAttrStorage : public AttributeStorage {
   /// Initialize the storage given an MLIRContext.
   void initialize(MLIRContext *context);
 
+  /// The type of the string.
+  Type type;
   /// The raw string value.
   StringRef value;
   /// If the string value contains a dialect namespace prefix (e.g.

diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index fd289917c64cc..6e35f120e22c4 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -24,16 +24,12 @@ using namespace mlir::detail;
 // ElementsAttr
 //===----------------------------------------------------------------------===//
 
-ShapedType ElementsAttr::getType() const {
-  return Attribute::getType().cast<ShapedType>();
+Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
+  return elementsAttr.getType().getElementType();
 }
 
-Type ElementsAttr::getElementType(Attribute elementsAttr) {
-  return elementsAttr.getType().cast<ShapedType>().getElementType();
-}
-
-int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
-  return elementsAttr.getType().cast<ShapedType>().getNumElements();
+int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
+  return elementsAttr.getType().getNumElements();
 }
 
 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
@@ -51,9 +47,9 @@ bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
     return 0 <= dim && dim < shape[i];
   });
 }
-bool ElementsAttr::isValidIndex(Attribute elementsAttr,
+bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
                                 ArrayRef<uint64_t> index) {
-  return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
+  return isValidIndex(elementsAttr.getType(), index);
 }
 
 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {

diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 5b282bf868b59..021da17b3c334 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -261,6 +261,8 @@ StringAttr StringAttr::get(const Twine &twine, Type type) {
 
 StringRef StringAttr::getValue() const { return getImpl()->value; }
 
+Type StringAttr::getType() const { return getImpl()->type; }
+
 Dialect *StringAttr::getReferencedDialect() const {
   return getImpl()->referencedDialect;
 }
@@ -688,29 +690,28 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
 /// Custom storage to ensure proper memory alignment for the allocation of
 /// DenseArray of any element type.
 struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
-  using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
-                           ::llvm::ArrayRef<char>>;
+  using KeyTy =
+      std::tuple<ShapedType, DenseArrayBaseAttr::EltType, ArrayRef<char>>;
   DenseArrayBaseAttrStorage(ShapedType type,
                             DenseArrayBaseAttr::EltType eltType,
-                            ::llvm::ArrayRef<char> elements)
-      : AttributeStorage(type), eltType(eltType), elements(elements) {}
+                            ArrayRef<char> elements)
+      : type(type), eltType(eltType), elements(elements) {}
 
-  bool operator==(const KeyTy &tblgenKey) const {
-    return (getType() == std::get<0>(tblgenKey)) &&
-           (eltType == std::get<1>(tblgenKey)) &&
-           (elements == std::get<2>(tblgenKey));
+  bool operator==(const KeyTy &key) const {
+    return (type == std::get<0>(key)) && (eltType == std::get<1>(key)) &&
+           (elements == std::get<2>(key));
   }
 
-  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
-    return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
-                                std::get<2>(tblgenKey));
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              std::get<2>(key));
   }
 
   static DenseArrayBaseAttrStorage *
-  construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
-    auto type = std::get<0>(tblgenKey);
-    auto eltType = std::get<1>(tblgenKey);
-    auto elements = std::get<2>(tblgenKey);
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    auto type = std::get<0>(key);
+    auto eltType = std::get<1>(key);
+    auto elements = std::get<2>(key);
     if (!elements.empty()) {
       char *alloc = static_cast<char *>(
           allocator.allocate(elements.size(), alignof(uint64_t)));
@@ -721,14 +722,17 @@ struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
         DenseArrayBaseAttrStorage(type, eltType, elements);
   }
 
+  ShapedType type;
   DenseArrayBaseAttr::EltType eltType;
-  ::llvm::ArrayRef<char> elements;
+  ArrayRef<char> elements;
 };
 
 DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
   return getImpl()->eltType;
 }
 
+ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
+
 const int8_t *
 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
@@ -974,8 +978,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 
   // If the element type is not based on int/float/index, assume it is a string
   // type.
-  auto eltType = type.getElementType();
-  if (!type.getElementType().isIntOrIndexOrFloat()) {
+  Type eltType = type.getElementType();
+  if (!eltType.isIntOrIndexOrFloat()) {
     SmallVector<StringRef, 8> stringValues;
     stringValues.reserve(values.size());
     for (Attribute attr : values) {
@@ -995,14 +999,16 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
       llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
   APInt intVal;
   for (unsigned i = 0, e = values.size(); i < e; ++i) {
-    assert(eltType == values[i].getType() &&
-           "expected attribute value to have element type");
-    if (eltType.isa<FloatType>())
-      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
-    else if (eltType.isa<IntegerType, IndexType>())
-      intVal = values[i].cast<IntegerAttr>().getValue();
-    else
-      llvm_unreachable("unexpected element type");
+    if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
+      assert(floatAttr.getType() == eltType &&
+             "expected float attribute type to equal element type");
+      intVal = floatAttr.getValue().bitcastToAPInt();
+    } else {
+      auto intAttr = values[i].cast<IntegerAttr>();
+      assert(intAttr.getType() == eltType &&
+             "expected integer attribute type to equal element type");
+      intVal = intAttr.getValue();
+    }
 
     assert(intVal.getBitWidth() == bitWidth &&
            "expected value to have same bitwidth as element type");
@@ -1010,7 +1016,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   }
 
   // Handle the special encoding of splat of bool.
-  if (values.size() == 1 && values[0].getType().isInteger(1))
+  if (values.size() == 1 && eltType.isInteger(1))
     data[0] = data[0] ? -1 : 0;
 
   return DenseIntOrFPElementsAttr::getRaw(type, data);
@@ -1326,7 +1332,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
 }
 
 ShapedType DenseElementsAttr::getType() const {
-  return Attribute::getType().cast<ShapedType>();
+  return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
 }
 
 Type DenseElementsAttr::getElementType() const {
@@ -1546,8 +1552,9 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
 
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
 bool DenseFPElementsAttr::classof(Attribute attr) {
-  return attr.isa<DenseElementsAttr>() &&
-         attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
+  if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
+    return denseAttr.getType().getElementType().isa<FloatType>();
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1564,8 +1571,9 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
 
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
 bool DenseIntElementsAttr::classof(Attribute attr) {
-  return attr.isa<DenseElementsAttr>() &&
-         attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
+  if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
+    return denseAttr.getType().getElementType().isIntOrIndex();
+  return false;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 273faa89b826c..2e00bd4778c30 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -896,10 +896,6 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
                                                   MLIRContext *ctx,
                                                   TypeID attrID) {
   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
-
-  // If the attribute did not provide a type, then default to NoneType.
-  if (!storage->getType())
-    storage->setType(NoneType::get(ctx));
 }
 
 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {

diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index e2cee18c418d6..bcf698400b1d0 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -32,7 +32,9 @@ Type mlir::getElementTypeOrSelf(Value val) {
 }
 
 Type mlir::getElementTypeOrSelf(Attribute attr) {
-  return getElementTypeOrSelf(attr.getType());
+  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+    return getElementTypeOrSelf(typedAttr.getType());
+  return {};
 }
 
 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {

diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 7125dd34a2fae..d4ca8b25eaa13 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1652,7 +1652,9 @@ void ByteCodeExecutor::executeGetAttributeType() {
   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
   unsigned memIndex = read();
   Attribute attr = read<Attribute>();
-  Type type = attr ? attr.getType() : Type();
+  Type type; // Stays null for a null or non-TypedAttr attribute.
+  if (auto typedAttr = attr.dyn_cast_or_null<TypedAttr>())
+    type = typedAttr.getType();
 
   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
                           << "  * Result: " << type << "\n");

diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 37760d4b5d9ac..bf86218c360b8 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -283,7 +283,8 @@ bool AttrOrTypeParameter::isOptional() const {
 }
 
 Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
-  return getDefValue<llvm::StringInit>("defaultValue");
+  Optional<StringRef> result = getDefValue<llvm::StringInit>("defaultValue");
+  return result && !result->empty() ? result : llvm::None;
 }
 
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3b3fba1957d56..53323ee309365 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -823,7 +823,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
   if (auto type = attr.dyn_cast<TypeAttr>())
     return emitType(loc, type.getValue());
 
-  return emitError(loc, "cannot emit attribute of type ") << attr.getType();
+  return emitError(loc, "cannot emit attribute: ") << attr;
 }
 
 LogicalResult CppEmitter::emitOperands(Operation &op) {

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 942650b0caa41..47542a41bbb71 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -769,7 +769,8 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
 
   // Process the type for this bool literal
   uint32_t typeID = 0;
-  if (failed(processType(loc, boolAttr.getType(), typeID))) {
+  if (failed(
+          processType(loc, boolAttr.cast<IntegerAttr>().getType(), typeID))) {
     return 0;
   }
 

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e5e15be3b90d1..459d6b1887536 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -332,7 +332,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute with two elements of the same type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -547,7 +547,7 @@ func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
   // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
-  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] 
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
     {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
@@ -571,7 +571,7 @@ func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
   // expected-error@+1 {{op requires attribute 'layoutA'}}
-  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] 
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
     {shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
@@ -594,7 +594,7 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
   // expected-error@+1 {{op requires b1Op attribute}}
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
     {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,     
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
      shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }

diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index ef1a53b3ab3fb..8c64e4570dece 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -79,7 +79,7 @@ func.func @const() -> () {
 // -----
 
 func.func @unaccepted_std_attr() -> () {
-  // expected-error @+1 {{cannot have value of type 'none'}}
+  // expected-error @+1 {{cannot have attribute: unit}}
   %0 = spv.Constant unit : none
   return
 }

diff --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir
index 79624791e3ac1..57562555c9643 100644
--- a/mlir/test/IR/file-metadata-resources.mlir
+++ b/mlir/test/IR/file-metadata-resources.mlir
@@ -5,7 +5,7 @@
 // CHECK-NEXT:   blob1: "0x08000000010000000000000002000000000000000300000000000000"
 // CHECK-NEXT: }
 
-module attributes { test.blob_ref = #test.e1di64_elements<blob1> } {}
+module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>} {}
 
 {-#
   dialect_resources: {

diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fb99d6f7c9b54..c86f29cf15a96 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -53,18 +53,22 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
 }
 
 // An attribute testing AttributeSelfTypeParameter.
-def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
+def AttrWithSelfTypeParam
+    : Test_Attr<"AttrWithSelfTypeParam", [TypedAttrInterface]> {
   let mnemonic = "attr_with_self_type_param";
   let parameters = (ins AttributeSelfTypeParameter<"">:$type);
   let assemblyFormat = "";
 }
 
 // An attribute testing AttributeSelfTypeParameter.
-def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
+def AttrWithTypeBuilder
+    : Test_Attr<"AttrWithTypeBuilder", [TypedAttrInterface]> {
   let mnemonic = "attr_with_type_builder";
-  let parameters = (ins "::mlir::IntegerAttr":$attr);
-  let typeBuilder = "$_attr.getType()";
-  let hasCustomAssemblyFormat = 1;
+  let parameters = (ins
+    "::mlir::IntegerAttr":$attr,
+    AttributeSelfTypeParameter<"", "mlir::Type", "$attr.getType()">:$type
+  );
+  let assemblyFormat = "$attr";
 }
 
 def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
@@ -76,7 +80,7 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
 
 // Test support for ElementsAttrInterface.
 def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
-    ElementsAttrInterface
+    ElementsAttrInterface, TypedAttrInterface
   ]> {
   let mnemonic = "i64_elements";
   let parameters = (ins
@@ -215,7 +219,7 @@ def TestAttrWithTypeParam : Test_Attr<"TestAttrWithTypeParam"> {
 
 // Test self type parameter with assembly format.
 def TestAttrSelfTypeParameterFormat
-    : Test_Attr<"TestAttrSelfTypeParameterFormat"> {
+    : Test_Attr<"TestAttrSelfTypeParameterFormat", [TypedAttrInterface]> {
   let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
 
   let mnemonic = "attr_self_type_format";
@@ -237,7 +241,7 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
 
 // Test simple extern 1D vector using ElementsAttrInterface.
 def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
-    ElementsAttrInterface
+    ElementsAttrInterface, TypedAttrInterface
   ]> {
   let mnemonic = "e1di64_elements";
   let parameters = (ins

diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 810380cf9ff19..1fbd2920a0b2a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -27,21 +27,6 @@
 using namespace mlir;
 using namespace test;
 
-//===----------------------------------------------------------------------===//
-// AttrWithTypeBuilderAttr
-//===----------------------------------------------------------------------===//
-
-Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
-  IntegerAttr element;
-  if (parser.parseAttribute(element))
-    return Attribute();
-  return get(parser.getContext(), element);
-}
-
-void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
-  printer << " " << getAttr();
-}
-
 //===----------------------------------------------------------------------===//
 // CompoundAAttr
 //===----------------------------------------------------------------------===//
@@ -114,10 +99,11 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-LogicalResult TestAttrWithFormatAttr::verify(
-    function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
-    IntegerAttr three, ArrayRef<int> four,
-    ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) {
+LogicalResult
+TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                               int64_t one, std::string two, IntegerAttr three,
+                               ArrayRef<int> four,
+                               ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
   if (four.size() != static_cast<unsigned>(one))
     return emitError() << "expected 'one' to equal 'four.size()'";
   return success();

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 57a5376285138..3b89b188da49f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -554,7 +554,7 @@ def OperandsHaveSameType :
 def ResultHasSameTypeAsAttr :
     TEST_Op<"result_has_same_type_as_attr",
             [AllTypesMatch<["attr", "result"]>]> {
-  let arguments = (ins AnyAttr:$attr);
+  let arguments = (ins TypedAttrInterface:$attr);
   let results = (outs AnyType:$result);
   let assemblyFormat = "$attr `->` type($result) attr-dict";
 }
@@ -2310,7 +2310,7 @@ def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
 def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
     AllTypesMatch<["value1", "value2", "result"]>
   ]> {
-  let arguments = (ins AnyAttr:$value1, AnyType:$value2);
+  let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2);
   let results = (outs AnyType:$result);
   let assemblyFormat = "attr-dict $value1 `,` $value2";
 }
@@ -2338,7 +2338,7 @@ def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
 def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
     TypesMatchWith<"result type matches constant", "value", "result", "$_self">
   ]> {
-  let arguments = (ins AnyAttr:$value);
+  let arguments = (ins TypedAttrInterface:$value);
   let results = (outs AnyType:$result);
   let assemblyFormat = "attr-dict $value";
 }

diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 05b48671b5bae..960c3870a9c0d 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -164,9 +164,13 @@ def AttrC : TestAttr<"TestF"> {
 
 /// Test attribute with self type parameter
 
-// ATTR: TestGAttr::parse
-// ATTR:   return TestGAttr::get
-// ATTR:     odsType
+// ATTR-LABEL: Attribute TestGAttr::parse
+// ATTR: if (odsType)
+// ATTR:   if (auto reqType = odsType.dyn_cast<::mlir::Type>())
+// ATTR:     _result_type = reqType
+// ATTR: TestGAttr::get
+// ATTR-NEXT: *_result_a
+// ATTR-NEXT: _result_type.value_or(::mlir::NoneType::get(
 def AttrD : TestAttr<"TestG"> {
   let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
   let mnemonic = "attr_d";

diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 7fc186362df22..c5ef65124498b 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -77,11 +77,12 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DECL: int getWidthOfSomething() const;
 // DECL: ::test::SimpleTypeA getExampleTdType() const;
 // DECL: ::llvm::APFloat getApFloat() const;
+// DECL: ::mlir::Type getInner() const;
 
 // Check that AttributeSelfTypeParameter is handled properly.
 // DEF-LABEL: struct CompoundAAttrStorage
 // DEF: CompoundAAttrStorage(
-// DEF-SAME: : ::mlir::AttributeStorage(inner),
+// DEF-SAME: inner(inner)
 
 // DEF: bool operator==(const KeyTy &tblgenKey) const {
 // DEF-NEXT: return
@@ -89,14 +90,14 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) &&
 // DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) &&
 // DEF-SAME: (dims == std::get<3>(tblgenKey)) &&
-// DEF-SAME: (getType() == std::get<4>(tblgenKey));
+// DEF-SAME: (inner == std::get<4>(tblgenKey));
 
 // DEF: static CompoundAAttrStorage *construct
 // DEF: return new (allocator.allocate<CompoundAAttrStorage>())
 // DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
 
 // DEF: ::mlir::Type CompoundAAttr::getInner() const {
-// DEF-NEXT: return getImpl()->getType().cast<::mlir::Type>();
+// DEF-NEXT: return getImpl()->inner;
 }
 
 def C_IndexAttr : TestAttr<"Index"> {
@@ -127,18 +128,6 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
 // DECL-SAME:  detail::SingleParameterAttrStorage
 }
 
-// An attribute testing AttributeSelfTypeParameter.
-def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
-  let mnemonic = "attr_with_type_builder";
-  let parameters = (ins "::mlir::IntegerAttr":$attr);
-  let typeBuilder = "$_attr.getType()";
-  let hasCustomAssemblyFormat = 1;
-}
-
-// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
-// DEF: AttrWithTypeBuilderAttrStorage(::mlir::IntegerAttr attr)
-// DEF-SAME: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
-
 def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
   let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param);
 }

diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 3791d82257f3b..e3af519662cae 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -68,7 +68,7 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
 
 // CHECK-LABEL: OpE definitions
 // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
-// CHECK: odsState.addTypes({attr.getValue().getType()});
+// CHECK: odsState.addTypes({attr.getValue().cast<::mlir::TypedAttr>().getType()});
 
 def OpF : NS_Op<"one_variadic_result_op", []> {
   let results = (outs Variadic<I32>:$x);
@@ -155,5 +155,5 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
 // CHECK-NOT: }
-// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
+// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;

diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 6e9489ebc47c8..05f73c49bd9f4 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -295,11 +295,7 @@ void DefGen::emitAccessors() {
     // class. Otherwise, let the user define the exact accessor definition.
     if (!def.genStorageClass())
       continue;
-    auto scope = m->body().indent().scope("return getImpl()->", ";");
-    if (isa<AttributeSelfTypeParameter>(param))
-      m->body() << formatv("getType().cast<{0}>()", param.getCppType());
-    else
-      m->body() << param.getName();
+    m->body().indent() << "return getImpl()->" << param.getName() << ";";
   }
 }
 
@@ -450,37 +446,8 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
 void DefGen::emitStorageConstructor() {
   Constructor *ctor =
       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
-  if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
-    // For attributes, a parameter marked with AttributeSelfTypeParameter is
-    // the type initializer that must be passed to the parent constructor.
-    const auto isSelfType = [](const AttrOrTypeParameter &param) {
-      return isa<AttributeSelfTypeParameter>(param);
-    };
-    auto *selfTypeParam = llvm::find_if(params, isSelfType);
-    if (std::count_if(selfTypeParam, params.end(), isSelfType) > 1) {
-      PrintFatalError(def.getLoc(),
-                      "Only one attribute parameter can be marked as "
-                      "AttributeSelfTypeParameter");
-    }
-    // Alternatively, if a type builder was specified, use that instead.
-    std::string attrStorageInit =
-        selfTypeParam == params.end() ? "" : selfTypeParam->getName().str();
-    if (attrDef->getTypeBuilder()) {
-      FmtContext ctx;
-      for (auto &param : params)
-        ctx.addSubst(strfmt("_{0}", param.getName()), param.getName());
-      attrStorageInit = tgfmt(*attrDef->getTypeBuilder(), &ctx);
-    }
-    ctor->addMemberInitializer("::mlir::AttributeStorage",
-                               std::move(attrStorageInit));
-    // Initialize members that aren't the attribute's type.
-    for (auto &param : params)
-      if (selfTypeParam == params.end() || *selfTypeParam != param)
-        ctor->addMemberInitializer(param.getName(), param.getName());
-  } else {
-    for (auto &param : params)
-      ctor->addMemberInitializer(param.getName(), param.getName());
-  }
+  for (auto &param : params)
+    ctor->addMemberInitializer(param.getName(), param.getName());
 }
 
 void DefGen::emitKeyType() {
@@ -498,9 +465,7 @@ void DefGen::emitEquals() {
   auto &body = eq->body().indent();
   auto scope = body.scope("return (", ");");
   const auto eachFn = [&](auto it) {
-    FmtContext ctx({{"_lhs", isa<AttributeSelfTypeParameter>(it.value())
-                                 ? "getType()"
-                                 : it.value().getName()},
+    FmtContext ctx({{"_lhs", it.value().getName()},
                     {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
     body << tgfmt(it.value().getComparator(), &ctx);
   };
@@ -566,8 +531,7 @@ void DefGen::emitStorageClass() {
   // Emit the storage class members as public, at the very end of the struct.
   storageCls->finalize();
   for (auto &param : params)
-    if (!isa<AttributeSelfTypeParameter>(param))
-      storageCls->declare<Field>(param.getCppType(), param.getName());
+    storageCls->declare<Field>(param.getCppType(), param.getName());
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 8321861738d08..752556fc129cb 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -246,6 +246,39 @@ class DefFormat {
 // ParserGen
 //===----------------------------------------------------------------------===//
 
+/// Generate a special-case "parser" for an attribute's self type parameter. The
+/// self type parameter has special handling in the assembly format in that it
+/// is derived from the optional trailing colon type after the attribute.
+static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
+                                  const AttributeSelfTypeParameter &param) {
+  // "Parser" for an attribute self type parameter that checks the
+  // optionally-parsed trailing colon type.
+  //
+  // $0: The C++ storage class of the type parameter.
+  // $1: The self type parameter name.
+  const char *const selfTypeParser = R"(
+if ($_type) {
+  if (auto reqType = $_type.dyn_cast<$0>()) {
+    _result_$1 = reqType;
+  } else {
+    $_parser.emitError($_loc, "invalid kind of type specified");
+    return {};
+  }
+})";
+
+  // If the attribute self type parameter is required, emit code that emits an
+  // error if the trailing type was not parsed.
+  const char *const selfTypeRequired = R"( else {
+  $_parser.emitError($_loc, "expected a trailing type");
+  return {};
+})";
+
+  os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName());
+  if (!param.isOptional())
+    os << tgfmt(selfTypeRequired, &ctx);
+  os << "\n";
+}
+
 void DefFormat::genParser(MethodBody &os) {
   FmtContext ctx;
   ctx.addSubst("_parser", "odsParser");
@@ -262,8 +295,6 @@ void DefFormat::genParser(MethodBody &os) {
   // a loop (parsers return FailureOr anyways).
   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
   for (const AttrOrTypeParameter &param : params) {
-    if (isa<AttributeSelfTypeParameter>(param))
-      continue;
     os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
                   param.getCppStorageType(), param.getName());
   }
@@ -281,7 +312,9 @@ void DefFormat::genParser(MethodBody &os) {
   // Emit an assert for each mandatory parameter. Triggering an assert means
   // the generated parser is incorrect (i.e. there is a bug in this code).
   for (const AttrOrTypeParameter &param : params) {
-    if (param.isOptional() || isa<AttributeSelfTypeParameter>(param))
+    if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(&param))
+      genAttrSelfTypeParser(os, ctx, *selfTypeParam);
+    if (param.isOptional())
       continue;
     os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
   }
@@ -306,11 +339,10 @@ void DefFormat::genParser(MethodBody &os) {
       else
         selfOs << param.getCppStorageType() << "()";
       selfOs << "))";
-    } else if (isa<AttributeSelfTypeParameter>(param)) {
-      selfOs << tgfmt("$_type", &ctx);
     } else {
       selfOs << formatv("(*_result_{0})", param.getName());
     }
+    ctx.addSubst(param.getName(), selfOs.str());
     os << param.getCppType() << "("
        << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
        << ")";

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 3330fdf3c28a4..0b03c76b8ba59 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -578,7 +578,8 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
   // Populate substitutions for attributes.
   auto &op = emitHelper.getOp();
   for (const auto &namedAttr : op.getAttributes())
-    ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
+    ctx.addSubst(namedAttr.name,
+                 emitHelper.getOp().getGetterName(namedAttr.name) + "()");
 
   // Populate substitutions for named operands.
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
@@ -1756,7 +1757,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
   if (namedAttr.attr.isTypeAttr()) {
     resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
   } else {
-    resultType = "attr.getValue().getType()";
+    resultType = "attr.getValue().cast<::mlir::TypedAttr>().getType()";
   }
 
   // Operands
@@ -2416,7 +2417,8 @@ void OpEmitter::genTypeInterfaceMethods() {
     } else {
       auto *attr =
           op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
-      body << "attributes.get(\"" << attr->name << "\").getType()";
+      body << "attributes.get(\"" << attr->name
+           << "\").cast<::mlir::TypedAttr>().getType()";
     }
     body << ";\n";
   }

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index b46fda38fbc60..82a1bcd5d1735 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -237,16 +237,19 @@ TEST(SparseElementsAttrTest, GetZero) {
 
   // Only index (0, 0) contains an element, others are supposed to return
   // the zero/empty value.
-  auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}];
-  EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+  auto zeroIntValue =
+      sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
+  EXPECT_EQ(zeroIntValue.getInt(), 0);
   EXPECT_TRUE(zeroIntValue.getType() == intTy);
 
-  auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}];
-  EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
+  auto zeroFloatValue =
+      sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
+  EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
 
-  auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}];
-  EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+  auto zeroStringValue =
+      sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
+  EXPECT_TRUE(zeroStringValue.getValue().empty());
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
 


        


More information about the Mlir-commits mailing list