[Mlir-commits] [mlir] ee09087 - Change the printing/parsing behavior for Attributes used in declarative assembly format

Mehdi Amini llvmlistbot at llvm.org
Tue Dec 7 18:02:46 PST 2021


Author: Mehdi Amini
Date: 2021-12-08T02:02:37Z
New Revision: ee0908703d2917d7310b71c5078fef44e8270317

URL: https://github.com/llvm/llvm-project/commit/ee0908703d2917d7310b71c5078fef44e8270317
DIFF: https://github.com/llvm/llvm-project/commit/ee0908703d2917d7310b71c5078fef44e8270317.diff

LOG: Change the printing/parsing behavior for Attributes used in declarative assembly format

The new form of printing attributes (and types) in the declarative assembly format elides the `#dialect.mnemonic` (respectively `!dialect.mnemonic`) prefix and keeps only the `<....>` part.

Differential Revision: https://reviews.llvm.org/D113873

Added: 
    mlir/test/lib/Dialect/Test/TestDialect.td

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/Parser/AsmParserImpl.h
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/Dialect/Async/runtime.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/test/mlir-tblgen/testdialect-typedefs.mlir
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index b3d41401b1394..b0cfd5856e06e 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
     vector operations, including a scalable vector type and intrinsics for
     some Arm SVE instructions.
   }];
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
     "Type":$elementType
   );
 
-  let printer = [{
-    $_printer << "<";
-    for (int64_t dim : getShape())
-      $_printer << dim << 'x';
-    $_printer << getElementType() << '>';
-  }];
-
-  let parser = [{
-    VectorType vector;
-    if ($_parser.parseType(vector))
-      return Type();
-    return get($_ctxt, vector.getShape(), vector.getElementType());
-  }];
-
   let extraClassDeclaration = [{
     bool hasStaticShape() const {
       return llvm::none_of(getShape(), ShapedType::isDynamic);

diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1d0c4887e25e5..70afd029c94f9 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -64,19 +64,19 @@ struct FieldParser<
                                  AttributeT>> {
   static FailureOr<AttributeT> parse(AsmParser &parser) {
     AttributeT value;
-    if (parser.parseAttribute(value))
+    if (parser.parseCustomAttributeWithFallback(value))
       return failure();
     return value;
   }
 };
 
 /// Parse a type.
 template <typename TypeT>
 struct FieldParser<
     TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
   static FailureOr<TypeT> parse(AsmParser &parser) {
     TypeT value;
-    if (parser.parseType(value))
+    if (parser.parseCustomTypeWithFallback(value))
       return failure();
     return value;
   }

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 26212d397575e..7e2326ab395ea 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
               string baseCppClass = "::mlir::Type">
     : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
       AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+  // Make it possible to use such types as parameters of other types.
+  string cppType = dialect.cppNamespace # "::" # cppClassName;
+
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
                            "$_builder.getType<" # dialect.cppNamespace #

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index dab6e106f9512..609f6be55a2c9 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -50,6 +50,36 @@ class AsmPrinter {
   virtual void printType(Type type);
   virtual void printAttribute(Attribute attr);
 
+  /// Trait to check if `AttrOrType` provides a `print` method.
+  template <typename AttrOrType>
+  using has_print_method =
+      decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
+  template <typename AttrOrType>
+  using detect_has_print_method =
+      llvm::is_detected<has_print_method, AttrOrType>;
+
+  /// Print the provided attribute or type in the context of an operation
+  /// custom printer: print its alias if one exists, otherwise invoke `print`
+  /// directly, eliding the `#dialect.mnemonic`/`!dialect.mnemonic` prefix.
+  template <typename AttrOrType,
+            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
+                *sfinae = nullptr>
+  void printStrippedAttrOrType(AttrOrType attrOrType) {
+    if (succeeded(printAlias(attrOrType)))
+      return;
+    attrOrType.print(*this);
+  }
+
+  /// SFINAE for printing the provided attribute or type in the context of an
+  /// operation custom printer in the case where the class does not define a
+  /// print method: fall back to the regular `<<` printer.
+  template <typename AttrOrType,
+            std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
+                *sfinae = nullptr>
+  void printStrippedAttrOrType(AttrOrType attrOrType) {
+    *this << attrOrType;
+  }
+
   /// Print the given attribute without its type. The corresponding parser must
   /// provide a valid type for the attribute.
   virtual void printAttributeWithoutType(Attribute attr);
@@ -102,6 +132,14 @@ class AsmPrinter {
   AsmPrinter(const AsmPrinter &) = delete;
   void operator=(const AsmPrinter &) = delete;
 
+  /// Print the alias for the given attribute, return failure if no alias could
+  /// be printed.
+  virtual LogicalResult printAlias(Attribute attr);
+
+  /// Print the alias for the given type, return failure if no alias could
+  /// be printed.
+  virtual LogicalResult printAlias(Type type);
+
   /// The internal implementation of the printer.
   Impl *impl;
 };
@@ -608,6 +646,13 @@ class AsmParser {
   /// Parse an arbitrary attribute of a given type and return it in result.
   virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
 
+  /// Parse a custom attribute with the provided callback, unless the next
+  /// token is `#`, in which case the generic parser is invoked.
+  virtual ParseResult parseCustomAttributeWithFallback(
+      Attribute &result, Type type,
+      function_ref<ParseResult(Attribute &result, Type type)>
+          parseAttribute) = 0;
+
   /// Parse an attribute of a specific kind and type.
   template <typename AttrType>
   ParseResult parseAttribute(AttrType &result, Type type = {}) {
@@ -639,9 +684,9 @@ class AsmParser {
     return parseAttribute(result, Type(), attrName, attrs);
   }
 
-  /// Parse an arbitrary attribute of a given type and return it in result. This
-  /// also adds the attribute to the specified attribute list with the specified
-  /// name.
+  /// Parse an arbitrary attribute of a given type and populate it in `result`.
+  /// This also adds the attribute to the specified attribute list with the
+  /// specified name.
   template <typename AttrType>
   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
                              NamedAttrList &attrs) {
@@ -661,6 +706,82 @@ class AsmParser {
     return success();
   }
 
+  /// Trait to check if `AttrType` provides a `parse` method.
+  template <typename AttrType>
+  using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
+                                                    std::declval<Type>()));
+  template <typename AttrType>
+  using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
+
+  /// Parse a custom attribute of a given type unless the next token is `#`, in
+  /// which case the generic parser is invoked. The parsed attribute is
+  /// populated in `result` and also added to the specified attribute list with
+  /// the specified name.
+  template <typename AttrType>
+  std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+  parseCustomAttributeWithFallback(AttrType &result, Type type,
+                                   StringRef attrName, NamedAttrList &attrs) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of attribute.
+    Attribute attr;
+    if (parseCustomAttributeWithFallback(
+            attr, type, [&](Attribute &result, Type type) -> ParseResult {
+              result = AttrType::parse(*this, type);
+              if (!result)
+                return failure();
+              return success();
+            }))
+      return failure();
+
+    // Check for the right kind of attribute.
+    result = attr.dyn_cast<AttrType>();
+    if (!result)
+      return emitError(loc, "invalid kind of attribute specified");
+
+    attrs.append(attrName, result);
+    return success();
+  }
+
+  /// SFINAE parsing method for attributes that don't implement a parse method.
+  template <typename AttrType>
+  std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+  parseCustomAttributeWithFallback(AttrType &result, Type type,
+                                   StringRef attrName, NamedAttrList &attrs) {
+    return parseAttribute(result, type, attrName, attrs);
+  }
+
+  /// Parse a custom attribute of a given type unless the next token is `#`, in
+  /// which case the generic parser is invoked. The parsed attribute is
+  /// populated in `result`.
+  template <typename AttrType>
+  std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+  parseCustomAttributeWithFallback(AttrType &result) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of attribute.
+    Attribute attr;
+    if (parseCustomAttributeWithFallback(
+            attr, {}, [&](Attribute &result, Type type) -> ParseResult {
+              result = AttrType::parse(*this, type);
+              return success(!!result);
+            }))
+      return failure();
+
+    // Check for the right kind of attribute.
+    result = attr.dyn_cast<AttrType>();
+    if (!result)
+      return emitError(loc, "invalid kind of attribute specified");
+    return success();
+  }
+
+  /// SFINAE parsing method for attributes that don't implement a parse method.
+  template <typename AttrType>
+  std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+  parseCustomAttributeWithFallback(AttrType &result) {
+    return parseAttribute(result);
+  }
+
   /// Parse an arbitrary optional attribute of a given type and return it in
   /// result.
   virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
@@ -740,6 +861,11 @@ class AsmParser {
   /// Parse a type.
   virtual ParseResult parseType(Type &result) = 0;
 
+  /// Parse a custom type with the provided callback, unless the next
+  /// token is `!`, in which case the generic parser is invoked.
+  virtual ParseResult parseCustomTypeWithFallback(
+      Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
+
   /// Parse an optional type.
   virtual OptionalParseResult parseOptionalType(Type &result) = 0;
 
@@ -753,7 +879,7 @@ class AsmParser {
     if (parseType(type))
       return failure();
 
-    // Check for the right kind of attribute.
+    // Check for the right kind of type.
     result = type.dyn_cast<TypeT>();
     if (!result)
       return emitError(loc, "invalid kind of type specified");
@@ -761,6 +887,44 @@ class AsmParser {
     return success();
   }
 
+  /// Trait to check if `TypeT` provides a `parse` method.
+  template <typename TypeT>
+  using type_has_parse_method =
+      decltype(TypeT::parse(std::declval<AsmParser &>()));
+  template <typename TypeT>
+  using detect_type_has_parse_method =
+      llvm::is_detected<type_has_parse_method, TypeT>;
+
+  /// Parse a custom type of kind `TypeT` unless the next token is `!`, in
+  /// which case the generic parser is invoked. The parsed type is
+  /// populated in `result`.
+  template <typename TypeT>
+  std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
+  parseCustomTypeWithFallback(TypeT &result) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of Type.
+    Type type;
+    if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
+          result = TypeT::parse(*this);
+          return success(!!result);
+        }))
+      return failure();
+
+    // Check for the right kind of Type.
+    result = type.dyn_cast<TypeT>();
+    if (!result)
+      return emitError(loc, "invalid kind of Type specified");
+    return success();
+  }
+
+  /// SFINAE parsing method for types that don't implement a parse method.
+  template <typename TypeT>
+  std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
+  parseCustomTypeWithFallback(TypeT &result) {
+    return parseType(result);
+  }
+
   /// Parse a type list.
   ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
     do {
@@ -792,7 +956,7 @@ class AsmParser {
     if (parseColonType(type))
       return failure();
 
-    // Check for the right kind of attribute.
+    // Check for the right kind of type.
     result = type.dyn_cast<TypeType>();
     if (!result)
       return emitError(loc, "invalid kind of type specified");

diff  --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 9f6632164b04e..482f0c351449b 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
 // ScalableVectorType
 //===----------------------------------------------------------------------===//
 
-Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
-  llvm::SMLoc typeLoc = parser.getCurrentLocation();
-  {
-    Type genType;
-    auto parseResult = generatedTypeParser(parser, "vector", genType);
-    if (parseResult.hasValue())
-      return genType;
-  }
-  parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
-  return Type();
+void ScalableVectorType::print(AsmPrinter &printer) const {
+  printer << "<";
+  for (int64_t dim : getShape())
+    printer << dim << 'x';
+  printer << getElementType() << '>';
 }
 
-void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
-  if (failed(generatedTypePrinter(type, os)))
-    llvm_unreachable("unexpected 'arm_sve' type kind");
+Type ScalableVectorType::parse(AsmParser &parser) {
+  SmallVector<int64_t> dims;
+  Type eltType;
+  if (parser.parseLess() ||
+      parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
+      parser.parseType(eltType) || parser.parseGreater())
+    return {};
+  return ScalableVectorType::get(eltType.getContext(), dims, eltType);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1b18b19df7e82..bbe7b7413f2cd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
 };
 
 void CombiningKindAttr::print(AsmPrinter &printer) const {
-  printer << "kind<";
+  printer << "<";
   auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
     return bitEnumContains(this->getKind(), kind);
   });
@@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
 
 void VectorDialect::printAttribute(Attribute attr,
                                    DialectAsmPrinter &os) const {
-  if (auto ck = attr.dyn_cast<CombiningKindAttr>())
+  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
+    os << "kind";
     ck.print(os);
-  else
-    llvm_unreachable("Unknown attribute type");
+    return;
+  }
+  llvm_unreachable("Unknown attribute type");
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6d50838cc9ead..61364758c641c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1188,7 +1188,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 /// Ex:
 /// ```
 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
-///   %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
+///   %1 = vector.multi_reduction add, %0 [1]
 ///     : vector<8x32x16xf32> to vector<8x16xf32>
 /// ```
 /// Gets converted to:
@@ -1198,7 +1198,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
 ///    iterator_types = ["parallel", "parallel", "reduction"],
-///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    kind = add} %0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 ///  ```
 struct MultiReduceToContract
@@ -1247,7 +1247,7 @@ struct MultiReduceToContract
 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
 ///    iterator_types = ["parallel", "parallel", "reduction"],
-///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    kind = add} %0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 /// ```
 /// Gets converted to:
@@ -1257,7 +1257,7 @@ struct MultiReduceToContract
 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
 ///    iterator_types = ["parallel", "parallel", "reduction"],
-///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 ///  ```
 struct CombineContractTranspose
@@ -1304,7 +1304,7 @@ struct CombineContractTranspose
 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
 ///    iterator_types = ["parallel", "parallel", "reduction"],
-///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    kind = add} %0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 /// ```
 /// Gets converted to:
@@ -1314,7 +1314,7 @@ struct CombineContractTranspose
 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
 ///    iterator_types = ["parallel", "parallel", "reduction"],
-///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 ///  ```
 struct CombineContractBroadcast

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 454a42e9f40b7..d3c239d9489f3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -474,6 +474,14 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   void printAttributeWithoutType(Attribute attr) override {
     printAttribute(attr);
   }
+  LogicalResult printAlias(Attribute attr) override {
+    initializer.visit(attr);
+    return success();
+  }
+  LogicalResult printAlias(Type type) override {
+    initializer.visit(type);
+    return success();
+  }
 
   /// Print the given set of attributes with names not included within
   /// 'elidedAttrs'.
@@ -1252,8 +1260,16 @@ class AsmPrinter::Impl {
   void printAttribute(Attribute attr,
                       AttrTypeElision typeElision = AttrTypeElision::Never);
 
+  /// Print the alias for the given attribute, return failure if no alias could
+  /// be printed.
+  LogicalResult printAlias(Attribute attr);
+
   void printType(Type type);
 
+  /// Print the alias for the given type, return failure if no alias could
+  /// be printed.
+  LogicalResult printAlias(Type type);
+
   /// Print the given location to the stream. If `allowAlias` is true, this
   /// allows for the internal location to use an attribute alias.
   void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -1594,6 +1610,14 @@ static void printElidedElementsAttr(raw_ostream &os) {
   os << R"(opaque<"_", "0xDEADBEEF">)";
 }
 
+LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
+  return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
+}
+
+LogicalResult AsmPrinter::Impl::printAlias(Type type) {
+  return success(state && succeeded(state->getAliasState().getAlias(type, os)));
+}
+
 void AsmPrinter::Impl::printAttribute(Attribute attr,
                                       AttrTypeElision typeElision) {
   if (!attr) {
@@ -1602,7 +1626,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
   }
 
   // Try to print an alias for this attribute.
-  if (state && succeeded(state->getAliasState().getAlias(attr, os)))
+  if (succeeded(printAlias(attr)))
     return;
 
   if (!isa<BuiltinDialect>(attr.getDialect()))
@@ -2104,6 +2128,16 @@ void AsmPrinter::printAttribute(Attribute attr) {
   impl->printAttribute(attr);
 }
 
+LogicalResult AsmPrinter::printAlias(Attribute attr) {
+  assert(impl && "expected AsmPrinter::printAlias to be overriden");
+  return impl->printAlias(attr);
+}
+
+LogicalResult AsmPrinter::printAlias(Type type) {
+  assert(impl && "expected AsmPrinter::printAlias to be overriden");
+  return impl->printAlias(type);
+}
+
 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
   assert(impl &&
          "expected AsmPrinter::printAttributeWithoutType to be overriden");

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 211687166b501..38c7e165f343f 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -374,6 +374,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
 
 //===----------------------------------------------------------------------===//
 // BoolAttr
+//===----------------------------------------------------------------------===//
 
 bool BoolAttr::getValue() const {
   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 33ed6b60932d4..10c38a86314fa 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/BitVector.h"
@@ -633,7 +634,7 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
     return true;
 
   // Allow custom dialect attributes.
-  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
+  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
     return true;
 
   return false;

diff  --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
index 70039c2736a42..e8c70b7df965c 100644
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -343,6 +343,29 @@ class AsmParserImpl : public BaseT {
     return success(static_cast<bool>(result));
   }
 
+  /// Parse a custom attribute with the provided callback, unless the next
+  /// token is `#`, in which case the generic parser is invoked.
+  ParseResult parseCustomAttributeWithFallback(
+      Attribute &result, Type type,
+      function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
+      override {
+    if (parser.getToken().isNot(Token::hash_identifier))
+      return parseAttribute(result, type);
+    result = parser.parseAttribute(type);
+    return success(static_cast<bool>(result));
+  }
+
+  /// Parse a custom type with the provided callback, unless the next
+  /// token is `!`, in which case the generic parser is invoked.
+  ParseResult parseCustomTypeWithFallback(
+      Type &result,
+      function_ref<ParseResult(Type &result)> parseType) override {
+    if (parser.getToken().isNot(Token::exclamation_identifier))
+      return parseType(result);
+    result = parser.parseType();
+    return success(static_cast<bool>(result));
+  }
+
   OptionalParseResult parseOptionalAttribute(Attribute &result,
                                              Type type) override {
     return parser.parseOptionalAttribute(result, type);

diff  --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0c1da2741eb8b..6f4247f30625c 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -3,7 +3,7 @@
 func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
                    %b: !arm_sve.vector<16xi8>,
                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
   %0 = arm_sve.sdot %c, %a, %b :
              !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
   return %0 : !arm_sve.vector<4xi32>
@@ -12,7 +12,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
 func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
                     %b: !arm_sve.vector<16xi8>,
                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
   %0 = arm_sve.smmla %c, %a, %b :
              !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
   return %0 : !arm_sve.vector<4xi32>
@@ -21,7 +21,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
 func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
                    %b: !arm_sve.vector<16xi8>,
                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
   %0 = arm_sve.udot %c, %a, %b :
              !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
   return %0 : !arm_sve.vector<4xi32>
@@ -30,7 +30,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
 func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
                     %b: !arm_sve.vector<16xi8>,
                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
   %0 = arm_sve.ummla %c, %a, %b :
              !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
   return %0 : !arm_sve.vector<4xi32>

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 9feb788402257..c570c584cbe67 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -16,7 +16,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
   %0 = arith.addf %arg0, %arg0 : f32
 // CHECK:   %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
   %1 = async.runtime.create: !async.value<f32>
-// CHECK:   async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK:   async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : <f32>
   async.runtime.store %0, %1: !async.value<f32>
 // CHECK:   async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
   async.runtime.set_available %1: !async.value<f32>
@@ -32,9 +32,9 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
 // CHECK:   cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
 
 // CHECK: ^[[BRANCH_OK]]:
-// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
 // CHECK:   %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32
-// CHECK:   async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK:   async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
 // CHECK:   async.runtime.set_available %[[RETURNED_STORAGE]]
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 // CHECK:   br ^[[CLEANUP]]
@@ -84,8 +84,8 @@ func @simple_caller() -> f32 {
 // CHECK:   cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
 
 // CHECK: ^[[BRANCH_VALUE_OK]]:
-// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
-// CHECK:   async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
+// CHECK:   async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
 // CHECK:   async.runtime.set_available %[[RETURNED_STORAGE]]
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 // CHECK:   br ^[[CLEANUP]]
@@ -133,7 +133,7 @@ func @double_caller() -> f32 {
 // CHECK:   cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
 
 // CHECK: ^[[BRANCH_VALUE_OK_1]]:
-// CHECK:   %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
+// CHECK:   %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
 // CHECK:   %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value<f32>)
 // CHECK:   %[[SAVED_2:.*]] = async.coro.save %[[HDL]]
 // CHECK:   async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]]
@@ -150,8 +150,8 @@ func @double_caller() -> f32 {
 // CHECK:   cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
 
 // CHECK: ^[[BRANCH_VALUE_OK_2]]:
-// CHECK:   %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
-// CHECK:   async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK:   %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
+// CHECK:   async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
 // CHECK:   async.runtime.set_available %[[RETURNED_STORAGE]]
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 // CHECK:   br ^[[CLEANUP]]

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index c024957dd7664..34532e56db7b7 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -245,7 +245,7 @@ func @execute_and_return_f32() -> f32 {
   }
 
   // CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
-  // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
+  // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : <f32>
   %0 = async.await %result : !async.value<f32>
 
   // CHECK: return %[[VALUE]]
@@ -323,7 +323,7 @@ func @async_value_operands() {
 
 // // Load from the async.value argument after error checking.
 // CHECK: ^[[CONTINUATION:.*]]:
-// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
+// CHECK:   %[[LOADED:.*]] = async.runtime.load %[[ARG]] : <f32
 // CHECK:   arith.addf %[[LOADED]], %[[LOADED]] : f32
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 

diff  --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index 1daa4ea64b2fc..2841fa11d36bc 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -129,16 +129,16 @@ func @resume(%arg0: !async.coro.handle) {
 
 // CHECK-LABEL: @store
 func @store(%arg0: f32, %arg1: !async.value<f32>) {
-  // CHECK: async.runtime.store %arg0, %arg1 : !async.value<f32>
-  async.runtime.store %arg0, %arg1 : !async.value<f32>
+  // CHECK: async.runtime.store %arg0, %arg1 : <f32>
+  async.runtime.store %arg0, %arg1 : <f32>
   return
 }
 
 // CHECK-LABEL: @load
 func @load(%arg0: !async.value<f32>) -> f32 {
-  // CHECK: %0 = async.runtime.load %arg0 : !async.value<f32>
+  // CHECK: %0 = async.runtime.load %arg0 : <f32>
   // CHECK: return %0 : f32
-  %0 = async.runtime.load %arg0 : !async.value<f32>
+  %0 = async.runtime.load %arg0 : <f32>
   return %0 : f32
 }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index b7ef524475487..934f8e3499365 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -6,7 +6,7 @@
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
 
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
+// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
 // CHECK: arith.addf %{{.*}}, %{{.*}} : f32
   linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
             outs(%C: memref<f32>)
@@ -19,7 +19,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
 
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
 // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
   linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
             outs(%C: memref<1584xf32>)
@@ -31,7 +31,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
 // CHECK-LABEL: contraction_matmul
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
 // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
   linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
             outs(%C: memref<1584x1584xf32>)
@@ -43,7 +43,7 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
 // CHECK-LABEL: contraction_batch_matmul
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
 // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
   linalg.batch_matmul
     ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
@@ -71,7 +71,7 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
   //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
   //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
@@ -105,7 +105,7 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
   //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
   //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
   linalg.generic #matmul_transpose_out_trait
@@ -139,7 +139,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
   //       CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
   //       CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
 
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
@@ -160,7 +160,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   //       CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
   //       CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
@@ -523,7 +523,7 @@ func @matmul_tensors(
   // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
   // convert it to a 2D contract.
   //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
-  //       CHECK:   %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+  //       CHECK:   %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
   //       CHECK:   %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
   //       CHECK:   %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@@ -744,7 +744,7 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
-  // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
+  // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
   // CHECK: addf {{.*}} : vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
@@ -779,7 +779,7 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
-  // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+  // CHECK: vector.multi_reduction <add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
   // CHECK: addf {{.*}} : vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>
@@ -808,7 +808,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
-  // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
@@ -833,7 +833,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
-  // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %maxf32 = arith.constant 3.40282e+38 : f32
@@ -857,7 +857,7 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
-  // CHECK: vector.multi_reduction #vector.kind<mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: mulf {{.*}} : vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant 1.0 : f32
@@ -881,7 +881,7 @@ func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction #vector.kind<or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -904,7 +904,7 @@ func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction #vector.kind<and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant true
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -927,7 +927,7 @@ func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction #vector.kind<xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -979,7 +979,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
   // CHECK: subf {{.*}} : vector<4x4xf32>
   // CHECK: math.exp {{.*}} : vector<4x4xf32>
-  // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
   // CHECK: addf {{.*}} : vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
   %c0 = arith.constant 0.0 : f32
@@ -1019,7 +1019,7 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
   //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
-  //      CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
+  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
   //      CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
   //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3557cebae0af6..9b496f857b1ab 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1027,7 +1027,7 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
 // CHECK-LABEL: func @vector_multi_reduction_single_parallel(
 //  CHECK-SAME:     %[[v:.*]]: vector<2xf32>
 func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
 
 //       CHECK:     return %[[v]] : vector<2xf32>
     return %0 : vector<2xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 63e30cf8912e7..195e720d2ad99 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -3,7 +3,7 @@
 // -----
 
 func @broadcast_to_scalar(%arg0: f32) -> f32 {
-  // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}}
+  // expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}}
   %0 = vector.broadcast %arg0 : f32 to f32
 }
 
@@ -1022,7 +1022,7 @@ func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 // -----
 
 func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{must be vector of any type values}}
+  // expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43c5abdd9ef8c..a68d105186bfc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -685,9 +685,9 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
 
 // CHECK-LABEL: @multi_reduction
 func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
-  %1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
+  %1 = vector.multi_reduction <add>, %0 [1, 3] :
     vector<4x8x16x32xf32> to vector<4x16xf32>
-  %2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
+  %2 = vector.multi_reduction <add>, %1 [0, 1] :
     vector<4x16xf32> to f32
   return %2 : f32
 }

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 7e6c1713d455e..76936c5d06e9c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
 
 func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 // CHECK-LABEL: func @vector_multi_reduction
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 //       CHECK:       return %[[RESULT_VEC]]
 
 func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
-    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+    %0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
 }
 // CHECK-LABEL: func @vector_multi_reduction_to_scalar
@@ -30,7 +30,7 @@ func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
 //       CHECK:   return %[[RES]]
 
 func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
-    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
 }
 // CHECK-LABEL: func @vector_reduction_inner
@@ -66,7 +66,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
 
 
 func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
-    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+    %0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
     return %0 : vector<2x5xf32>
 }
 
@@ -78,7 +78,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
 //       CHECK:       return %[[RESULT]]
 
 func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
-    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+    %0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
 }
 // CHECK-LABEL: func @vector_multi_reduction_ordering

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 3f0e184274dce..f94a0b6e1a960 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
 
 func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
@@ -35,7 +35,7 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
@@ -52,7 +52,7 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    %0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
@@ -69,7 +69,7 @@ func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
 
 func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    %0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
@@ -86,7 +86,7 @@ func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
 
 func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    %0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
@@ -104,7 +104,7 @@ func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
 
 
 func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
-    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index ccab0617e241e..85389b2a767d6 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -12,7 +12,7 @@
 func @multidimreduction_contract(
   %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
-  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+  %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
   return %1 : vector<8x16xf32>
 }
 
@@ -30,7 +30,7 @@ func @multidimreduction_contract(
 func @multidimreduction_contract_int(
   %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
   %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
-  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+  %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
   return %1 : vector<8x16xi32>
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index a9721cf31de81..8d1723adfc681 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -14,7 +14,7 @@
 #define TEST_ATTRDEFS
 
 // To get the test dialect definition.
-include "TestOps.td"
+include "TestDialect.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/SubElementInterfaces.td"
 
@@ -121,6 +121,29 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
   );
 }
 
+// A more complex parameterized attribute with multiple level of nesting.
+def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> {
+  let mnemonic = "cmpnd_nested_inner";
+  // List of type parameters.
+  let parameters = (
+    ins
+    "int":$some_int,
+    CompoundAttrA:$cmpdA
+  );
+  let assemblyFormat = "`<` $some_int $cmpdA `>`";
+}
+
+def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> {
+  let mnemonic = "cmpnd_nested_outer";
+
+  // List of type parameters.
+  let parameters = (
+    ins
+    CompoundNestedInner:$inner
+  );
+  let assemblyFormat = "`<` `i`  $inner `>`";
+}
+
 def TestParamOne : AttrParameter<"int64_t", ""> {}
 
 def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
new file mode 100644
index 0000000000000..756c7c10c92db
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -0,0 +1,46 @@
+//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_DIALECT
+#define TEST_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "::test";
+  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+  let hasCanonicalizer = 1;
+  let hasConstantMaterializer = 1;
+  let hasOperationAttrVerify = 1;
+  let hasRegionArgAttrVerify = 1;
+  let hasRegionResultAttrVerify = 1;
+  let hasOperationInterfaceFallback = 1;
+  let hasNonDefaultDestructor = 1;
+  let useDefaultAttributePrinterParser = 1;
+  let dependentDialects = ["::mlir::DLTIDialect"];
+
+  let extraClassDeclaration = [{
+    void registerAttributes();
+    void registerTypes();
+
+    // Provides a custom printing/parsing for some operations.
+    ::llvm::Optional<ParseOpHook>
+      getParseOperationHook(::llvm::StringRef opName) const override;
+    ::llvm::unique_function<void(::mlir::Operation *,
+                                 ::mlir::OpAsmPrinter &printer)>
+     getOperationPrinter(::mlir::Operation *op) const override;
+
+  private:
+    // Storage for a custom fallback interface.
+    void *fallbackEffectOpInterfaces;
+
+  }];
+}
+
+#endif // TEST_DIALECT

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2bea95017fa5f..4f6abed2eda26 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -9,6 +9,7 @@
 #ifndef TEST_OPS
 #define TEST_OPS
 
+include "TestDialect.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -23,40 +24,11 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "TestInterfaces.td"
 
-def Test_Dialect : Dialect {
-  let name = "test";
-  let cppNamespace = "::test";
-  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
-  let hasCanonicalizer = 1;
-  let hasConstantMaterializer = 1;
-  let hasOperationAttrVerify = 1;
-  let hasRegionArgAttrVerify = 1;
-  let hasRegionResultAttrVerify = 1;
-  let hasOperationInterfaceFallback = 1;
-  let hasNonDefaultDestructor = 1;
-  let useDefaultAttributePrinterParser = 1;
-  let dependentDialects = ["::mlir::DLTIDialect"];
-
-  let extraClassDeclaration = [{
-    void registerAttributes();
-    void registerTypes();
-
-    // Provides a custom printing/parsing for some operations.
-    ::llvm::Optional<ParseOpHook>
-      getParseOperationHook(::llvm::StringRef opName) const override;
-    ::llvm::unique_function<void(::mlir::Operation *,
-                                 ::mlir::OpAsmPrinter &printer)>
-     getOperationPrinter(::mlir::Operation *op) const override;
-
-  private:
-    // Storage for a custom fallback interface.
-    void *fallbackEffectOpInterfaces;
-
-  }];
-}
 
 // Include the attribute definitions.
 include "TestAttrDefs.td"
+// Include the type definitions.
+include "TestTypeDefs.td"
 
 
 class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -1933,6 +1905,16 @@ def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
   let assemblyFormat = "$nested attr-dict-with-keyword";
 }
 
+def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
+  let arguments = (ins CompoundNestedOuter:$nested);
+  let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
+}
+
+def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
+  let arguments = (ins CompoundNestedOuterType:$nested);
+  let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
+}
+
 //===----------------------------------------------------------------------===//
 // Custom Directives
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 553fc54e9acae..7bd77e9ce42be 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -14,8 +14,9 @@
 #define TEST_TYPEDEFS
 
 // To get the test dialect def.
-include "TestOps.td"
+include "TestDialect.td"
 include "TestAttrDefs.td"
+include "TestInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 
@@ -49,6 +50,29 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
   }];
 }
 
+// A more complex and nested parameterized type.
+def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> {
+  let mnemonic = "cmpnd_inner";
+  // List of type parameters.
+  let parameters = (
+    ins
+    "int":$some_int,
+    CompoundTypeA:$cmpdA
+  );
+  let assemblyFormat = "`<` $some_int $cmpdA `>`";
+}
+
+def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> {
+  let mnemonic = "cmpnd_nested_outer";
+
+  // List of type parameters.
+  let parameters = (
+    ins
+    CompoundNestedInnerType:$inner
+  );
+  let assemblyFormat = "`<` `i`  $inner `>`";
+}
+
 // An example of how one could implement a standard integer.
 def IntegerType : Test_Type<"TestInteger"> {
   let mnemonic = "int";

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 22a18c5461d50..0b050034a6b43 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -25,6 +25,7 @@
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
 namespace test {
+class TestAttrWithFormatAttr;
 
 /// FieldInfo represents a field in the StructType data type. It is used as a
 /// parameter in TestTypeDefs.td.
@@ -63,13 +64,13 @@ struct FieldParser<test::CustomParam> {
     return test::CustomParam{value.getValue()};
   }
 };
-} // end namespace mlir
-
 inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
-                                    const test::CustomParam &param) {
+                                    test::CustomParam param) {
+  // Streams only the wrapped `value`. This overload sits in namespace mlir,
+  // presumably so ADL on mlir::AsmPrinter finds it from generated printers.
   return printer << param.value;
 }
 
+} // end namespace mlir
+
 #include "TestTypeInterfaces.h.inc"
 
 #define GET_TYPEDEF_CLASSES

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 888b9856da761..a0588bfb237df 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -61,7 +61,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
 // ATTR:   printer << ' ' << "hello";
 // ATTR:   printer << ' ' << "=";
 // ATTR:   printer << ' ';
-// ATTR:   printer << getValue();
+// ATTR:   printer.printStrippedAttrOrType(getValue());
 // ATTR:   printer << ",";
 // ATTR:   printer << ' ';
 // ATTR:   ::printAttrParamA(printer, getComplex());
@@ -154,10 +154,10 @@ def AttrB : TestAttr<"TestB"> {
 
 // ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
 // ATTR:   printer << ' ';
-// ATTR:   printer << getV0();
+// ATTR:   printer.printStrippedAttrOrType(getV0());
 // ATTR:   printer << ",";
 // ATTR:   printer << ' ';
-// ATTR:   printer << getV1();
+// ATTR:   printer.printStrippedAttrOrType(getV1());
 // ATTR: }
 
 def AttrC : TestAttr<"TestF"> {
@@ -213,7 +213,7 @@ def AttrC : TestAttr<"TestF"> {
 // TYPE:   printer << ' ' << "bob";
 // TYPE:   printer << ' ' << "bar";
 // TYPE:   printer << ' ';
-// TYPE:   printer << getValue();
+// TYPE:   printer.printStrippedAttrOrType(getValue());
 // TYPE:   printer << ' ' << "complex";
 // TYPE:   printer << ' ' << "=";
 // TYPE:   printer << ' ';
@@ -361,21 +361,21 @@ def TypeB : TestType<"TestD"> {
 // TYPE:   printer << "v0";
 // TYPE:   printer << ' ' << "=";
 // TYPE:   printer << ' ';
-// TYPE:   printer << getV0();
+// TYPE:   printer.printStrippedAttrOrType(getV0());
 // TYPE:   printer << ",";
 // TYPE:   printer << ' ' << "v2";
 // TYPE:   printer << ' ' << "=";
 // TYPE:   printer << ' ';
-// TYPE:   printer << getV2();
+// TYPE:   printer.printStrippedAttrOrType(getV2());
 // TYPE:   printer << "v1";
 // TYPE:   printer << ' ' << "=";
 // TYPE:   printer << ' ';
-// TYPE:   printer << getV1();
+// TYPE:   printer.printStrippedAttrOrType(getV1());
 // TYPE:   printer << ",";
 // TYPE:   printer << ' ' << "v3";
 // TYPE:   printer << ' ' << "=";
 // TYPE:   printer << ' ';
-// TYPE:   printer << getV3();
+// TYPE:   printer.printStrippedAttrOrType(getV3());
 // TYPE: }
 
 def TypeC : TestType<"TestE"> {

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index dd996b8acc6bc..c3214c7afab4d 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -256,16 +256,50 @@ test.format_optional_else else
 // Format a custom attribute
 //===----------------------------------------------------------------------===//
 
-// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
-test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
+// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]>
+test.format_compound_attr <1, !test.smpla, [5, 6]>
 
-// CHECK:  module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
+//-----
+
+// The nested parameter round-trips without its `#test.cmpnd_a` prefix.
+// CHECK:   module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
+module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
+}
+
+//-----
+
+// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
+// CHECK:   module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
 module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
 }
 
-// CHECK: test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
+// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
+test.format_nested_attr #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>
+
+//-----
+
+// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
+// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
 test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
 
+//-----
+
+// The innermost `<1, ...>` parameter is printed without a dialect prefix.
+// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
+module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
+{
+}
+
+//-----
+
+// Two levels of nesting: every inner parameter is printed stripped.
+// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
+module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
+{
+}
+
+//-----
+
+// The op's attribute itself is parsed and printed without a dialect prefix.
+// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
+test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
 
 //===----------------------------------------------------------------------===//
 // Format custom directives

diff  --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index 783b4be704c77..4ab6b0e86a279 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -13,6 +13,22 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
   return
 }
 
+// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
+func @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>) -> () {
+  return
+}
+
+// Same as above, but spelling out the inner `!test.cmpnd_inner` prefix.
+// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
+func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> () {
+// The outer type prefix is elided when printing and optional when parsing.
+// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+  test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>
+  test.format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+  return
+}
+
 // CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
 func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
   return

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 01ae08e57be9e..94e81f1085b5b 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -163,7 +163,8 @@ static const char *const defaultParameterParser =
     "::mlir::FieldParser<$0>::parse($_parser)";
 
 /// Default printer for attribute or type parameters.
-static const char *const defaultParameterPrinter = "$_printer << $_self";
+///
+/// printStrippedAttrOrType prints attribute/type parameters without their
+/// `#dialect.mnemonic` / `!dialect.mnemonic` prefix; other parameter kinds
+/// are presumably still streamed as before (TODO: confirm the overloads).
+static const char *const defaultParameterPrinter =
+    "$_printer.printStrippedAttrOrType($_self)";
 
 /// Print an error when failing to parse an element.
 ///

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0827146da180f..08785516b99f2 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -496,13 +496,25 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
 /// {0}: The name of the attribute.
 /// {1}: The type for the attribute.
 const char *const attrParserCode = R"(
-  if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
+  if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
+          result.attributes)) {{
+    return ::mlir::failure();
+  }
+)";
+
+/// The code snippet used to generate a parser call for an attribute whose
+/// storage type is the generic `::mlir::Attribute`; presumably there is no
+/// concrete attribute class whose custom parser could be tried first.
+///
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
+const char *const genericAttrParserCode = R"(
+  if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
     return ::mlir::failure();
 )";
+
 const char *const optionalAttrParserCode = R"(
   {
     ::mlir::OptionalParseResult parseResult =
-      parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
+      parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
   }
@@ -635,8 +647,12 @@ const char *const optionalTypeParserCode = R"(
   }
 )";
+/// The code snippet used to generate a parser call for a single type.
+/// parseCustomTypeWithFallback presumably tries the stripped custom form of
+/// `{0}` first and falls back to the full type syntax (TODO: confirm).
+///
+/// {0}: The C++ class of the type constraint, used for the local `type`.
+/// {1}: The name prefix of the `RawTypes` buffer that receives the type.
 const char *const typeParserCode = R"(
-  if (parser.parseType({0}RawTypes[0]))
-    return ::mlir::failure();
+  {
+    {0} type;
+    if (parser.parseCustomTypeWithFallback(type))
+      return ::mlir::failure();
+    {1}RawTypes[0] = type;
+  }
 )";
 
 /// The code snippet used to generate a parser call for a functional type.
@@ -1269,12 +1285,19 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
     std::string attrTypeStr;
     if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
       llvm::raw_string_ostream os(attrTypeStr);
-      os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
+      os << tgfmt(*typeBuilder, &attrTypeCtx);
+    } else {
+      attrTypeStr = "Type{}";
+    }
+    if (var->attr.isOptional()) {
+      body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+    } else {
+      if (var->attr.getStorageType() == "::mlir::Attribute")
+        body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+      else
+        body << formatv(attrParserCode, var->name, attrTypeStr);
     }
 
-    body << formatv(var->attr.isOptional() ? optionalAttrParserCode
-                                           : attrParserCode,
-                    var->name, attrTypeStr);
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
     StringRef name = operand->getVar()->name;
@@ -1334,14 +1357,23 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
-    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
       body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
-    else if (lengthKind == ArgumentLengthKind::Variadic)
+    } else if (lengthKind == ArgumentLengthKind::Variadic) {
       body << llvm::formatv(variadicTypeParserCode, listName);
-    else if (lengthKind == ArgumentLengthKind::Optional)
+    } else if (lengthKind == ArgumentLengthKind::Optional) {
       body << llvm::formatv(optionalTypeParserCode, listName);
-    else
-      body << formatv(typeParserCode, listName);
+    } else {
+      TypeSwitch<Element *>(dir->getOperand())
+          .Case<OperandVariable, ResultVariable>([&](auto operand) {
+            body << formatv(typeParserCode,
+                            operand->getVar()->constraint.getCPPClassName(),
+                            listName);
+          })
+          .Default([&](auto operand) {
+            body << formatv(typeParserCode, "Type", listName);
+          });
+    }
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     ArgumentLengthKind ignored;
     body << formatv(functionalTypeParserCode,
@@ -1761,7 +1793,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
 
 /// Generate the C++ for an operand to a (*-)type directive.
 static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
-                                         MethodBody &body) {
+                                         MethodBody &body,
+                                         bool useArrayRef = true) {
   if (isa<OperandsDirective>(arg))
     return body << "getOperation()->getOperandTypes()";
   if (isa<ResultsDirective>(arg))
@@ -1778,8 +1811,10 @@ static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
                "::llvm::ArrayRef<::mlir::Type>())",
                op.getGetterName(var->name));
-  return body << "::llvm::ArrayRef<::mlir::Type>("
-              << op.getGetterName(var->name) << "().getType())";
+  if (useArrayRef)
+    return body << "::llvm::ArrayRef<::mlir::Type>("
+                << op.getGetterName(var->name) << "().getType())";
+  return body << op.getGetterName(var->name) << "().getType()";
 }
 
 /// Generate the printer for an enum attribute.
@@ -1978,9 +2013,15 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
     if (attr->getTypeBuilder())
       body << "  _odsPrinter.printAttributeWithoutType("
            << op.getGetterName(var->name) << "Attr());\n";
-    else
+    else if (var->attr.isOptional())
+      body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name)
+           << "Attr());\n";
+    else if (var->attr.getStorageType() == "::mlir::Attribute")
       body << "  _odsPrinter.printAttribute(" << op.getGetterName(var->name)
            << "Attr());\n";
+    else
+      body << "_odsPrinter.printStrippedAttrOrType("
+           << op.getGetterName(var->name) << "Attr());\n";
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     if (operand->getVar()->isVariadicOfVariadic()) {
       body << "  ::llvm::interleaveComma("
@@ -2033,8 +2074,29 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
         return;
       }
     }
+    const NamedTypeConstraint *var = nullptr;
+    {
+      if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
+        var = operand->getVar();
+      else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
+        var = operand->getVar();
+    }
+    if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
+        !var->isOptional()) {
+      std::string cppClass = var->constraint.getCPPClassName();
+      body << "  {\n"
+           << "    auto type = " << op.getGetterName(var->name)
+           << "().getType();\n"
+           << "    if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
+           << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
+           << "   else\n"
+           << "     _odsPrinter << type;\n"
+           << "  }\n";
+      return;
+    }
     body << "  _odsPrinter << ";
-    genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n";
+    genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
+        << ";\n";
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     body << "  _odsPrinter.printFunctionalType(";
     genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";


        


More information about the Mlir-commits mailing list