[Mlir-commits] [mlir] [mlir] Add support for parsing and printing cyclic aliases (PR #66663)

Markus Böck llvmlistbot at llvm.org
Mon Sep 18 09:06:01 PDT 2023


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/66663

Final part of https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236

Until now, printing mutable attributes and types as aliases was disabled entirely, as parsing them would end up in infinite recursion.
This PR fixes the issue by using the recently added `tryStartCyclicParse` function to register a mutable attribute or type parsed as part of an alias definition as soon as its immutable key has been parsed.
This makes it possible to break the recursion cycle and lets parsing succeed. Combined with a previous patch that made the parser insensitive to the order of consecutive alias definitions, we can now also enable the printing of mutable attributes and types as aliases.
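
To illustrate the idea of early registration, here is a minimal, self-contained sketch. It does not use MLIR's parser or `tryStartCyclicParse`; `Node`, `Definition` and `CyclicAliasParser` are hypothetical names, and the "immutable key" is modelled as a plain string. The only point is that registering the alias before parsing the body lets a self-reference resolve instead of recursing forever.

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a mutable type: an immutable key plus a body that
// is filled in after the object has been created.
struct Node {
  std::string key;
  std::vector<Node *> fields;
};

// Hypothetical unparsed alias definition: the key and the aliases that the
// body refers to.
struct Definition {
  std::string key;
  std::vector<std::string> fieldAliases;
};

class CyclicAliasParser {
public:
  explicit CyclicAliasParser(std::map<std::string, Definition> defs)
      : unparsed(std::move(defs)) {}

  // Returns the node for `alias`, or nullptr if the alias is undefined.
  Node *parseAlias(const std::string &alias) {
    // An alias that is already registered (possibly still being parsed) is
    // returned as-is. This is what breaks the cycle.
    if (auto it = registered.find(alias); it != registered.end())
      return it->second.get();

    auto def = unparsed.find(alias);
    if (def == unparsed.end())
      return nullptr;

    // Early registration: the key is known, so register the alias before
    // looking at the body.
    auto node = std::make_unique<Node>();
    node->key = def->second.key;
    Node *raw = node.get();
    registered[alias] = std::move(node);

    // Parse the body. References back to `alias` now hit the map above.
    for (const std::string &field : def->second.fieldAliases) {
      Node *fieldNode = parseAlias(field);
      if (!fieldNode)
        return nullptr;
      raw->fields.push_back(fieldNode);
    }
    return raw;
  }

private:
  std::map<std::string, Definition> unparsed;
  std::map<std::string, std::unique_ptr<Node>> registered;
};
```

In this toy model, a definition whose body refers to its own alias yields a node whose field points back at itself. Without the early registration step, the same input would recurse without bound.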

Depends on https://github.com/llvm/llvm-project/pull/65503. While we lack stacked PRs, please review the first commit on that PR.

From a6863faeb903f7020dc0be53b71388e6f394bd88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Tue, 5 Sep 2023 22:33:57 +0200
Subject: [PATCH 1/2] [mlir] Add concept of alias blocks

This PR is part of https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236

It implements the concept of "alias blocks": a block of alias definitions in which any definition may reference any other definition within the block, regardless of definition order. This is purely a convenience for immutable attributes and types, but it is a requirement for supporting alias definitions in cyclic mutable attributes and types.

The implementation works by first parsing an alias block, which is simply a sequence of consecutive alias definitions, in a syntax-only mode. This syntax-only mode only checks the syntactic validity of the parsed attribute or type but does not verify any parsed data. This allows us to essentially skip over alias definitions in order to first collect them and associate every alias definition with its source region.

In a second pass, we can start parsing the attributes and types while at the same time attempting to resolve any unknown alias references with our list of yet-to-be-parsed attributes and types, parsing them on demand if required.
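
The two passes can be sketched with a toy model that is independent of MLIR. Everything here is an assumption made for illustration: `AliasBlock` and `AliasBlockResolver` are hypothetical names, an alias body is a whitespace-separated token list, and a token starting with `#` refers to another alias in the same block. Rejecting cycles is a choice of this sketch only.

```cpp
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy input: (alias name, unparsed body text) in source order.
using AliasBlock = std::vector<std::pair<std::string, std::string>>;

class AliasBlockResolver {
public:
  explicit AliasBlockResolver(const AliasBlock &block) {
    // Pass 1: record every definition without interpreting its body.
    for (const auto &[name, body] : block) {
      order.push_back(name);
      unparsed[name] = body;
    }
  }

  // Pass 2: parse every definition in source order, resolving references to
  // not-yet-parsed aliases on demand.
  std::optional<std::map<std::string, std::string>> resolveAll() {
    for (const std::string &name : order)
      if (!resolve(name))
        return std::nullopt;
    return resolved;
  }

private:
  std::optional<std::string> resolve(const std::string &name) {
    // Avoid parsing a definition twice.
    if (auto it = resolved.find(name); it != resolved.end())
      return it->second;

    auto it = unparsed.find(name);
    if (it == unparsed.end())
      return std::nullopt; // Undefined alias.

    // This sketch simply rejects cyclic references.
    if (!inProgress.insert(name).second)
      return std::nullopt;

    std::istringstream tokens(it->second);
    std::string token, result;
    while (tokens >> token) {
      if (token.front() == '#') {
        // Unknown alias reference: parse its definition now.
        std::optional<std::string> value = resolve(token.substr(1));
        if (!value)
          return std::nullopt;
        token = *value;
      }
      if (!result.empty())
        result += ' ';
      result += token;
    }

    inProgress.erase(name);
    resolved[name] = result;
    return result;
  }

  std::vector<std::string> order;
  std::map<std::string, std::string> unparsed;
  std::map<std::string, std::string> resolved;
  std::set<std::string> inProgress;
};
```

Keeping the source order in a separate vector makes the iteration order of the second pass deterministic, while the map lookup handles references to aliases defined later in the block.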

A later PR will hook this mechanism up to the `tryStartCyclicParse` method added in https://github.com/llvm/llvm-project/commit/b121c266744d030120c59e6256559cbccacd3c6f to register cyclic attributes and types early, breaking the parsing cycles.
---
 mlir/docs/LangRef.md                       |  31 ++-
 mlir/lib/AsmParser/AttributeParser.cpp     |  48 ++++-
 mlir/lib/AsmParser/DialectSymbolParser.cpp |  41 +++-
 mlir/lib/AsmParser/LocationParser.cpp      |  17 +-
 mlir/lib/AsmParser/Parser.cpp              | 228 +++++++++++++++------
 mlir/lib/AsmParser/Parser.h                |  21 +-
 mlir/lib/AsmParser/ParserState.h           |  20 ++
 mlir/lib/AsmParser/TypeParser.cpp          |  25 ++-
 mlir/test/IR/alias-def-groups.mlir         |  25 +++
 9 files changed, 370 insertions(+), 86 deletions(-)
 create mode 100644 mlir/test/IR/alias-def-groups.mlir

diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 0cfe845638c3cfb..a00f9f76c4253b2 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line.
 
 ```
 // Top level production
-toplevel := (operation | attribute-alias-def | type-alias-def)*
+toplevel := (operation | alias-block-def)*
+alias-block-def := (attribute-alias-def | type-alias-def)*
 ```
 
 The production `toplevel` is the top level production that is parsed by any parsing
-consuming the MLIR syntax. [Operations](#operations),
-[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
+consuming the MLIR syntax. [Operations](#operations) and 
+[Alias Blocks](#alias-block-definitions) consisting of 
+[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases)
 can be declared on the toplevel.
 
 ### Identifiers and keywords
@@ -880,3 +882,26 @@ version using readAttribute and readType methods.
 There is no restriction on what kind of information a dialect is allowed to
 encode to model its versioning. Currently, versioning is supported only for
 bytecode formats.
+
+## Alias Block Definitions
+
+An alias block is a list of subsequent attribute or type alias definitions that
+are conceptually parsed as one unit.
+This allows any alias definition within the block to reference any other alias 
+definition within the block, regardless if defined lexically later or earlier in
+the block.
+
+```mlir
+// Alias block consisting of #array, !integer_type and #integer_attr.
+#array = [#integer_attr, !integer_type]
+!integer_type = i32
+#integer_attr = 8 : !integer_type
+
+// Illegal. !other_type is not part of this alias block and defined later 
+// in the file.
+!tuple = tuple<i32, !other_type>
+
+func.func @foo() { ... }
+
+!other_type = f32
+```
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 3437ac9addc5ff6..3453049a09bee05 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) {
         parseLocationInstance(locAttr) ||
         parseToken(Token::r_paren, "expected ')' in inline location"))
       return Attribute();
+
+    if (syntaxOnly())
+      return state.syntaxOnlyAttr;
+
     return locAttr;
   }
 
@@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return FloatAttr::get(floatType, *result);
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   if (!isa<IntegerType, IndexType>(type))
     return emitError(loc, "integer literal not valid for specified type"),
            nullptr;
@@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
-  return literalParser.getAttr(loc, type);
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+  return literalParser.getAttr(loc, cast<ShapedType>(type));
 }
 
 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
@@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
       return nullptr;
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   ShapedType shapedType = dyn_cast<ShapedType>(attrType);
   if (!shapedType) {
     emitError(typeLoc, "`dense_resource` expected a shaped type");
@@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
 ///   elements-literal-type ::= vector-type | ranked-tensor-type
 ///
 /// This method also checks the type has static shape.
-ShapedType Parser::parseElementsLiteralType(Type type) {
+Type Parser::parseElementsLiteralType(Type type) {
   // If the user didn't provide a type, parse the colon type for the literal.
   if (!type) {
     if (parseToken(Token::colon, "expected ':'"))
@@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
       return nullptr;
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   auto sType = dyn_cast<ShapedType>(type);
   if (!sType) {
     emitError("elements literal must be a shaped type");
@@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // of the type.
   Type indiceEltType = builder.getIntegerType(64);
   if (consumeIf(Token::greater)) {
-    ShapedType type = parseElementsLiteralType(attrType);
+    Type type = parseElementsLiteralType(attrType);
     if (!type)
       return nullptr;
 
+    if (syntaxOnly())
+      return state.syntaxOnlyAttr;
+
     // Construct the sparse elements attr using zero element indice/value
     // attributes.
+    ShapedType shapedType = cast<ShapedType>(type);
     ShapedType indicesType =
-        RankedTensorType::get({0, type.getRank()}, indiceEltType);
-    ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
+        RankedTensorType::get({0, shapedType.getRank()}, indiceEltType);
+    ShapedType valuesType =
+        RankedTensorType::get({0}, shapedType.getElementType());
     return getChecked<SparseElementsAttr>(
-        loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+        loc, shapedType,
+        DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
   }
 
@@ -1114,6 +1135,11 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   if (!type)
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
+  ShapedType shapedType = cast<ShapedType>(type);
+
   // If the indices are a splat, i.e. the literal parser parsed an element and
   // not a list, we set the shape explicitly. The indices are represented by a
   // 2-dimensional shape where the second dimension is the rank of the type.
@@ -1121,7 +1147,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // indice and thus one for the first dimension.
   ShapedType indicesType;
   if (indiceParser.getShape().empty()) {
-    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+    indicesType =
+        RankedTensorType::get({1, shapedType.getRank()}, indiceEltType);
   } else {
     // Otherwise, set the shape to the one parsed by the literal parser.
     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
@@ -1131,7 +1158,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // If the values are a splat, set the shape explicitly based on the number of
   // indices. The number of indices is encoded in the first dimension of the
   // indice shape type.
-  auto valuesEltType = type.getElementType();
+  auto valuesEltType = shapedType.getElementType();
   ShapedType valuesType =
       valuesParser.getShape().empty()
           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
@@ -1139,7 +1166,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   auto values = valuesParser.getAttr(valuesLoc, valuesType);
 
   // Build the sparse elements attribute by the indices and values.
-  return getChecked<SparseElementsAttr>(loc, type, indices, values);
+  return getChecked<SparseElementsAttr>(loc, shapedType, indices, values);
 }
 
 Attribute Parser::parseStridedLayoutAttr() {
@@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) {
       return {};
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   // Add the distinct attribute to the parser state, if it has not been parsed
   // before. Otherwise, check if the parsed reference attribute matches the one
   // found in the parser state.
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 2b1b114b90e86af..5330fbc9996ff9a 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
@@ -156,9 +157,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
 }
 
 /// Parse an extended dialect symbol.
-template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
+template <typename Symbol, typename SymbolAliasMap, typename ParseAliasFn,
+          typename CreateFn>
 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
                                   SymbolAliasMap &aliases,
+                                  ParseAliasFn &parseAliasFn,
                                   CreateFn &&createSymbol) {
   Token tok = p.getToken();
 
@@ -185,12 +188,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
   // If there is no '<' token following this, and if the typename contains no
   // dot, then we are parsing a symbol alias.
   if (!hasTrailingData && !isPrettyName) {
+
+    // Don't check the validity of alias reference in syntax-only mode.
+    if (p.syntaxOnly()) {
+      if constexpr (std::is_same_v<Symbol, Type>)
+        return p.getState().syntaxOnlyType;
+      else
+        return p.getState().syntaxOnlyAttr;
+    }
+
     // Check for an alias for this type.
     auto aliasIt = aliases.find(identifier);
-    if (aliasIt == aliases.end())
-      return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
-                                    "'"),
-              nullptr);
+    if (aliasIt == aliases.end()) {
+      FailureOr<Symbol> symbol = failure();
+      // Try the parse alias function if set.
+      if (parseAliasFn)
+        symbol = parseAliasFn(identifier);
+
+      if (failed(symbol)) {
+        p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
+        return nullptr;
+      }
+      if (!*symbol)
+        return nullptr;
+
+      aliasIt = aliases.insert({identifier, *symbol}).first;
+    }
     if (asmState) {
       if constexpr (std::is_same_v<Symbol, Type>)
         asmState->addTypeAliasUses(identifier, range);
@@ -241,12 +264,16 @@ Attribute Parser::parseExtendedAttr(Type type) {
   MLIRContext *ctx = getContext();
   Attribute attr = parseExtendedSymbol<Attribute>(
       *this, state.asmState, state.symbols.attributeAliasDefinitions,
+      state.symbols.parseUnknownAttributeAlias,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
         // Parse an optional trailing colon type.
         Type attrType = type;
         if (consumeIf(Token::colon) && !(attrType = parseType()))
           return Attribute();
 
+        if (syntaxOnly())
+          return state.syntaxOnlyAttr;
+
         // If we found a registered dialect, then ask it to parse the attribute.
         if (Dialect *dialect =
                 builder.getContext()->getOrLoadDialect(dialectName)) {
@@ -288,7 +315,11 @@ Type Parser::parseExtendedType() {
   MLIRContext *ctx = getContext();
   return parseExtendedSymbol<Type>(
       *this, state.asmState, state.symbols.typeAliasDefinitions,
+      state.symbols.parseUnknownTypeAlias,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
+        if (syntaxOnly())
+          return state.syntaxOnlyType;
+
         // If we found a registered dialect, then ask it to parse the type.
         if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
           // Temporarily reset the lexer to let the dialect parse the type.
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index 61b20179800c6cc..8139f188c32a740 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
   if (parseToken(Token::r_paren, "expected ')' in callsite location"))
     return failure();
 
+  if (syntaxOnly())
+    return success();
+
   // Return the callsite location.
   loc = CallSiteLoc::get(calleeLoc, callerLoc);
   return success();
@@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
     LocationAttr newLoc;
     if (parseLocationInstance(newLoc))
       return failure();
+    if (syntaxOnly())
+      return success();
+
     locations.push_back(newLoc);
     return success();
   };
@@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
     if (parseLocationInstance(childLoc))
       return failure();
 
-    loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
-
     // Parse the closing ')'.
     if (parseToken(Token::r_paren,
                    "expected ')' after child location of NameLoc"))
       return failure();
+
+    if (syntaxOnly())
+      return success();
+
+    loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
   } else {
     loc = NameLoc::get(StringAttr::get(ctx, str));
   }
@@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
     Attribute locAttr = parseExtendedAttr(Type());
     if (!locAttr)
       return failure();
+
+    if (syntaxOnly())
+      return success();
+
     if (!(loc = dyn_cast<LocationAttr>(locAttr)))
       return emitError("expected location attribute, but got") << locAttr;
     return success();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 84f44dba806df01..264348d32b281a7 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
 #include <algorithm>
 #include <memory>
@@ -2417,17 +2418,12 @@ class TopLevelOperationParser : public Parser {
   ParseResult parse(Block *topLevelBlock, Location parserLoc);
 
 private:
-  /// Parse an attribute alias declaration.
+  /// Parse an alias block definition.
   ///
   ///   attribute-alias-def ::= '#' alias-name `=` attribute-value
-  ///
-  ParseResult parseAttributeAliasDef();
-
-  /// Parse a type alias declaration.
-  ///
   ///   type-alias-def ::= '!' alias-name `=` type
-  ///
-  ParseResult parseTypeAliasDef();
+  ///   alias-block-def ::= (type-alias-def | attribute-alias-def)+
+  ParseResult parseAliasBlockDef();
 
   /// Parse a top-level file metadata dictionary.
   ///
@@ -2528,69 +2524,184 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
   Token value;
   Parser &p;
 };
+
+/// Convenient subclass of `ParserState` which configures the parser for
+/// syntax-only parsing. This only copies config and other required state for
+/// parsing but does not copy side-effecting state such as the code completion
+/// context.
+struct SyntaxParserState : ParserState {
+  explicit SyntaxParserState(ParserState &state)
+      : ParserState(state.lex.getSourceMgr(), state.config, state.symbols,
+                    /*asmState=*/nullptr,
+                    /*codeCompleteContext=*/nullptr) {
+    syntaxOnly = true;
+  }
+};
+
 } // namespace
 
-ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
-  assert(getToken().is(Token::hash_identifier));
-  StringRef aliasName = getTokenSpelling().drop_front();
+ParseResult TopLevelOperationParser::parseAliasBlockDef() {
 
-  // Check for redefinitions.
-  if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0)
-    return emitError("redefinition of attribute alias id '" + aliasName + "'");
+  struct UnparsedData {
+    SMRange location;
+    StringRef text;
+  };
 
-  // Make sure this isn't invading the dialect attribute namespace.
-  if (aliasName.contains('.'))
-    return emitError("attribute names with a '.' are reserved for "
-                     "dialect-defined names");
+  // Use a map vector as StringMap has non-deterministic iteration order.
+  using StringMapVector =
+      llvm::MapVector<StringRef, UnparsedData, llvm::StringMap<unsigned>>;
 
-  SMRange location = getToken().getLocRange();
-  consumeToken(Token::hash_identifier);
+  StringMapVector unparsedAttributeAliases;
+  StringMapVector unparsedTypeAliases;
 
-  // Parse the '='.
-  if (parseToken(Token::equal, "expected '=' in attribute alias definition"))
-    return failure();
+  // Returns true if this alias has already been defined, either in this block
+  // or a previous one.
+  auto isRedefinition = [&](bool isType, StringRef aliasName) {
+    if (isType)
+      return state.symbols.typeAliasDefinitions.contains(aliasName) ||
+             unparsedTypeAliases.contains(aliasName);
 
-  // Parse the attribute value.
-  Attribute attr = parseAttribute();
-  if (!attr)
-    return failure();
+    return state.symbols.attributeAliasDefinitions.contains(aliasName) ||
+           unparsedAttributeAliases.contains(aliasName);
+  };
 
-  // Register this alias with the parser state.
-  if (state.asmState)
-    state.asmState->addAttrAliasDefinition(aliasName, location, attr);
-  state.symbols.attributeAliasDefinitions[aliasName] = attr;
-  return success();
-}
+  // Collect all attribute or type alias definitions in unparsed form first.
+  while (
+      getToken().isAny(Token::exclamation_identifier, Token::hash_identifier)) {
+    StringRef aliasName = getTokenSpelling().drop_front();
+
+    bool isType = getToken().is(mlir::Token::exclamation_identifier);
+    StringRef kind = isType ? "type" : "attribute";
+
+    // Check for redefinitions.
+    if (isRedefinition(isType, aliasName))
+      return emitError("redefinition of ")
+             << kind << " alias id '" << aliasName << "'";
+
+    // Make sure this isn't invading the dialect namespace.
+    if (aliasName.contains('.'))
+      return emitError(kind) << " names with a '.' are reserved for "
+                                "dialect-defined names";
+
+    SMRange location = getToken().getLocRange();
+    consumeToken();
+
+    // Parse the '='.
+    if (parseToken(Token::equal,
+                   "expected '=' in " + kind + " alias definition"))
+      return failure();
 
-ParseResult TopLevelOperationParser::parseTypeAliasDef() {
-  assert(getToken().is(Token::exclamation_identifier));
-  StringRef aliasName = getTokenSpelling().drop_front();
+    SyntaxParserState skippingParserState(state);
+    Parser syntaxOnlyParser(skippingParserState);
+    const char *start = getToken().getLoc().getPointer();
+    syntaxOnlyParser.resetToken(start);
 
-  // Check for redefinitions.
-  if (state.symbols.typeAliasDefinitions.count(aliasName) > 0)
-    return emitError("redefinition of type alias id '" + aliasName + "'");
+    // Parse just the syntax of the value, moving the lexer past the definition.
+    if (isType ? !syntaxOnlyParser.parseType()
+               : !syntaxOnlyParser.parseAttribute())
+      return failure();
+
+    // Get the location from the lexers new position.
+    const char *end = syntaxOnlyParser.getToken().getLoc().getPointer();
+    size_t length = end - start;
 
-  // Make sure this isn't invading the dialect type namespace.
-  if (aliasName.contains('.'))
-    return emitError("type names with a '.' are reserved for "
-                     "dialect-defined names");
+    StringMapVector &unparsedMap =
+        isType ? unparsedTypeAliases : unparsedAttributeAliases;
 
-  SMRange location = getToken().getLocRange();
-  consumeToken(Token::exclamation_identifier);
+    unparsedMap[aliasName] =
+        UnparsedData{location, StringRef(start, length).rtrim()};
+
+    // Move the top-level parser past the alias definition.
+    resetToken(end);
+  }
+
+  auto parseAttributeAlias = [&](StringRef aliasName,
+                                 const UnparsedData &unparsedData) {
+    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+        getState().cyclicParsingStack, {});
+    auto exit = saveAndResetToken(unparsedData.text.data());
+    Attribute attribute = parseAttribute();
+    if (!attribute)
+      return attribute;
+
+    // Register this alias with the parser state.
+    if (state.asmState)
+      state.asmState->addAttrAliasDefinition(aliasName, unparsedData.location,
+                                             attribute);
 
-  // Parse the '='.
-  if (parseToken(Token::equal, "expected '=' in type alias definition"))
+    return attribute;
+  };
+
+  auto parseTypeAlias = [&](StringRef aliasName,
+                            const UnparsedData &unparsedData) {
+    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+        getState().cyclicParsingStack, {});
+    auto exit = saveAndResetToken(unparsedData.text.data());
+    Type type = parseType();
+    if (!type)
+      return type;
+
+    // Register this alias with the parser state.
+    if (state.asmState)
+      state.asmState->addTypeAliasDefinition(aliasName, unparsedData.location,
+                                             type);
+
+    return type;
+  };
+
+  // Set the callbacks for the lazy parsing of alias definitions in the parser.
+  state.symbols.parseUnknownAttributeAlias =
+      [&](StringRef aliasName) -> FailureOr<Attribute> {
+    auto *iter = unparsedAttributeAliases.find(aliasName);
+    if (iter == unparsedAttributeAliases.end())
+      return failure();
+
+    return parseAttributeAlias(aliasName, iter->second);
+  };
+  state.symbols.parseUnknownTypeAlias =
+      [&](StringRef aliasName) -> FailureOr<Type> {
+    auto *iter = unparsedTypeAliases.find(aliasName);
+    if (iter == unparsedTypeAliases.end())
+      return failure();
+
+    return parseTypeAlias(aliasName, iter->second);
+  };
+
+  // Reset them to nullptr at the end. Keeping them around would lead to the
+  // access of local variables captured in this scope after we've returned from
+  // this function.
+  auto exit = llvm::make_scope_exit([&] {
+    state.symbols.parseUnknownTypeAlias = nullptr;
+    state.symbols.parseUnknownAttributeAlias = nullptr;
+  });
+
+  // Now go through all the unparsed definitions in the block and parse them.
+  // The order here is not significant for correctness, but should be
+  // deterministic. The order can also have an impact on the maximum stack usage
+  // during parsing. This can be improved in the future.
+  auto parse = [](auto &unparsed, auto &definitions, auto &parseFn) {
+    for (auto &&[aliasName, unparsedData] : unparsed) {
+      // Avoid parsing twice.
+      if (definitions.contains(aliasName))
+        continue;
+
+      auto symbol = parseFn(aliasName, unparsedData);
+      if (!symbol)
+        return failure();
+      definitions[aliasName] = symbol;
+    }
+    return success();
+  };
+
+  if (failed(parse(unparsedAttributeAliases,
+                   state.symbols.attributeAliasDefinitions,
+                   parseAttributeAlias)))
     return failure();
 
-  // Parse the type.
-  Type aliasedType = parseType();
-  if (!aliasedType)
+  if (failed(parse(unparsedTypeAliases, state.symbols.typeAliasDefinitions,
+                   parseTypeAlias)))
     return failure();
 
-  // Register this alias with the parser state.
-  if (state.asmState)
-    state.asmState->addTypeAliasDefinition(aliasName, location, aliasedType);
-  state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
   return success();
 }
 
@@ -2729,15 +2840,10 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
     case Token::error:
       return failure();
 
-    // Parse an attribute alias.
+    // Parse an alias def block.
     case Token::hash_identifier:
-      if (parseAttributeAliasDef())
-        return failure();
-      break;
-
-    // Parse a type alias.
     case Token::exclamation_identifier:
-      if (parseTypeAliasDef())
+      if (parseAliasBlockDef())
         return failure();
       break;
 
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 01c55f97a08c2ce..282c0b13aebd930 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -12,6 +12,7 @@
 #include "ParserState.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include <optional>
 
 namespace mlir {
@@ -132,6 +133,22 @@ class Parser {
     state.curToken = state.lex.lexToken();
   }
 
+  /// Temporarily resets the parser to the given lexer position. The previous
+  /// lexer position is saved and restored on destruction of the returned
+  /// object.
+  [[nodiscard]] auto saveAndResetToken(const char *tokPos) {
+    const char *previous = getToken().getLoc().getPointer();
+    resetToken(tokPos);
+    return llvm::make_scope_exit([this, previous] { resetToken(previous); });
+  }
+
+  /// Returns true if the parser is in syntax-only mode. In this mode, the
+  /// parser only checks the syntactic validity of the parsed elements but does
+  /// not verify the correctness of the parsed data. Syntax-only mode is
+  /// currently only supported for attribute and type parsing and skips parsing
+  /// dialect attributes and types entirely.
+  bool syntaxOnly() { return state.syntaxOnly; }
+
   /// Consume the specified token if present and return success.  On failure,
   /// output a diagnostic and return failure.
   ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
@@ -209,7 +226,7 @@ class Parser {
   Type parseTupleType();
 
   /// Parse a vector type.
-  VectorType parseVectorType();
+  Type parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                        SmallVectorImpl<bool> &scalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
@@ -265,7 +282,7 @@ class Parser {
 
   /// Parse a dense elements attribute.
   Attribute parseDenseElementsAttr(Type attrType);
-  ShapedType parseElementsLiteralType(Type type);
+  Type parseElementsLiteralType(Type type);
 
   /// Parse a dense resource elements attribute.
   Attribute parseDenseResourceElementsAttr(Type attrType);
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 1428ea3a82cee9f..17602f84a90833e 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -32,6 +32,15 @@ struct SymbolState {
   /// A map from type alias identifier to Type.
   llvm::StringMap<Type> typeAliasDefinitions;
 
+  /// Parser functions set during the parsing of alias-block-defs to parse an
+  /// unknown attribute or type alias. The parameter is the name of the alias.
+  /// The function should return failure if no such alias could be found.
+  /// If any errors occurred during parsing, a null attribute or type should
+  /// be returned.
+  llvm::unique_function<FailureOr<Attribute>(StringRef)>
+      parseUnknownAttributeAlias;
+  llvm::unique_function<FailureOr<Type>(StringRef)> parseUnknownTypeAlias;
+
   /// A map of dialect resource keys to the resolved resource name and handle
   /// to use during parsing.
   DenseMap<const OpAsmDialectInterface *,
@@ -88,6 +97,17 @@ struct ParserState {
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  /// Controls whether the parser is in syntax-only mode.
+  bool syntaxOnly = false;
+
+  /// Attribute and type returned by `parseType`, `parseAttribute` and the more
+  /// specific parsing function to signal syntactic correctness if an attribute
+  /// or type cannot be created without verifying the parsed data as well.
+  /// Callers of such function should only check for null or not null return
+  /// values for error signaling.
+  Type syntaxOnlyType = NoneType::get(config.getContext());
+  Attribute syntaxOnlyAttr = UnitAttr::get(config.getContext());
 };
 
 } // namespace detail
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 306e850af27bc58..77e87cf9b0befb7 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -130,6 +130,10 @@ Type Parser::parseComplexType() {
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in complex type"))
     return nullptr;
+
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
     return emitError(elementTypeLoc, "invalid element type for complex"),
            nullptr;
@@ -150,6 +154,9 @@ Type Parser::parseFunctionType() {
       parseFunctionResultTypes(results))
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   return builder.getFunctionType(arguments, results);
 }
 
@@ -195,9 +202,10 @@ Type Parser::parseMemRefType() {
   if (!elementType)
     return nullptr;
 
-  // Check that memref is formed from allowed types.
-  if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError(typeLoc, "invalid memref element type"), nullptr;
+  if (!syntaxOnly()) { // Check that memref is formed from allowed types.
+    if (!BaseMemRefType::isValidElementType(elementType))
+      return emitError(typeLoc, "invalid memref element type"), nullptr;
+  }
 
   MemRefLayoutAttrInterface layout;
   Attribute memorySpace;
@@ -208,6 +216,9 @@ Type Parser::parseMemRefType() {
     if (!attr)
       return failure();
 
+    if (syntaxOnly())
+      return success();
+
     if (isa<MemRefLayoutAttrInterface>(attr)) {
       layout = cast<MemRefLayoutAttrInterface>(attr);
     } else if (memorySpace) {
@@ -235,6 +246,9 @@ Type Parser::parseMemRefType() {
     }
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (isUnranked)
     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
 
@@ -437,7 +451,7 @@ Type Parser::parseTupleType() {
 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
 ///
-VectorType Parser::parseVectorType() {
+Type Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
@@ -458,6 +472,9 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
diff --git a/mlir/test/IR/alias-def-groups.mlir b/mlir/test/IR/alias-def-groups.mlir
new file mode 100644
index 000000000000000..71d09371fae67a3
--- /dev/null
+++ b/mlir/test/IR/alias-def-groups.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -split-input-file %s | FileCheck %s
+
+#array = [#integer_attr, !integer_type]
+!integer_type = i32
+#integer_attr = 8 : !integer_type
+
+// CHECK-LABEL: func @foo()
+func.func @foo() {
+  // CHECK-NEXT: value = [8 : i32, i32]
+  "foo.attr"() { value = #array} : () -> ()
+}
+
+// -----
+
+// Check that only groups may reference later defined aliases.
+
+// expected-error@below {{undefined symbol alias id 'integer_attr'}}
+#array = [!integer_type, #integer_attr]
+!integer_type = i32
+
+func.func @foo() {
+  %0 = "foo.attr"() { value = #array}
+}
+
+#integer_attr = 8 : !integer_type

>From 18fd994028fa1ef9e3ddad18c38eabc3f9ae4fdd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Wed, 6 Sep 2023 20:37:05 +0200
Subject: [PATCH 2/2] [mlir] Add support for parsing and printing cyclic
 aliases

Up until now, printing mutable attributes and types as aliases was disabled entirely, because parsing them back would end up in infinite recursion.
This PR fixes these issues by using the recently added `tryStartCyclicParse` function to register a mutable attribute or type parsed as part of an alias definition as soon as its immutable key has been parsed.
This makes it possible to break the recursion cycle and make parsing succeed. Combined with a previous patch that made the parser insensitive to the order of consecutive alias definitions, we can also enable the printing of mutable attributes and types as aliases.
---
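Note for reviewers: the cycle-breaking idea described above can be shown with a small standalone model. This is only a sketch under stated assumptions, not the MLIR implementation: `Node`, `AliasDef` and `ToyParser` are hypothetical names invented for this note, and "parsing" is reduced to a map lookup. The one point it illustrates is that the alias is registered right after the uniqued object is created from its immutable key, before the body is resolved.

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for a mutable attribute or type: it is uniqued by its
// immutable `key`, and its body is filled in afterwards.
struct Node {
  std::string key;
  const Node *body = nullptr;
};

// Hypothetical alias definition of the shape `!alias = rec<key, !bodyAlias>`.
struct AliasDef {
  std::string key;
  std::string bodyAlias;
};

class ToyParser {
public:
  explicit ToyParser(std::map<std::string, AliasDef> defs)
      : defs(std::move(defs)) {}

  // Resolves `alias` to its node, or returns nullptr if it is not defined.
  const Node *resolve(const std::string &alias) {
    auto known = resolved.find(alias);
    if (known != resolved.end())
      return known->second;

    auto defIt = defs.find(alias);
    if (defIt == defs.end())
      return nullptr;
    const AliasDef &def = defIt->second;

    // The immutable key is enough to create (or look up) the uniqued node.
    Node *node = getOrCreate(def.key);

    // Register the alias *before* resolving the body. Without this line a
    // self-referential definition would recurse forever.
    resolved[alias] = node;

    // The body may refer back to `alias`; that lookup now hits `resolved`.
    node->body = resolve(def.bodyAlias);
    return node;
  }

private:
  Node *getOrCreate(const std::string &key) {
    std::unique_ptr<Node> &slot = storage[key];
    if (!slot)
      slot = std::make_unique<Node>(Node{key});
    return slot.get();
  }

  std::map<std::string, AliasDef> defs;
  std::map<std::string, const Node *> resolved;
  std::map<std::string, std::unique_ptr<Node>> storage;
};
```

In this model, resolving `a` with `a = rec<A, b>` and `b = rec<B, a>` terminates and yields a node whose body's body is the node itself. The `resolved[alias] = node` assignment plays roughly the role of the `attributeAliasDefinitions[aliasDefName] = attr` assignment in the `pushCyclicParsing` override of this patch.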
 mlir/include/mlir/IR/OpImplementation.h       |  8 +--
 mlir/lib/AsmParser/AsmParserImpl.h            |  5 +-
 mlir/lib/AsmParser/AttributeParser.cpp        |  4 +-
 mlir/lib/AsmParser/DialectSymbolParser.cpp    | 31 ++++++++--
 mlir/lib/AsmParser/Parser.cpp                 |  8 +--
 mlir/lib/AsmParser/Parser.h                   | 10 ++--
 mlir/lib/AsmParser/ParserState.h              |  2 +-
 mlir/lib/AsmParser/TypeParser.cpp             |  8 +--
 mlir/lib/IR/AsmPrinter.cpp                    | 19 ++-----
 .../Dialect/LLVMIR/types-typed-pointers.mlir  |  2 +-
 mlir/test/Dialect/SPIRV/IR/types.mlir         | 10 ++++
 ...type.mlir => recursive-type-and-attr.mlir} | 24 +++++++-
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 17 ++++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 56 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestAttributes.h   | 31 ++++++++++
 .../Dialect/Test/TestDialectInterfaces.cpp    |  5 ++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  2 +-
 17 files changed, 194 insertions(+), 48 deletions(-)
 rename mlir/test/IR/{recursive-type.mlir => recursive-type-and-attr.mlir} (58%)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 8864ef02cd3cbba..0893744c6a11960 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1365,7 +1365,7 @@ class AsmParser {
                           AttrOrTypeT> ||
             std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
         "Only mutable attributes or types can be cyclic");
-    if (failed(pushCyclicParsing(attrOrType.getAsOpaquePointer())))
+    if (failed(pushCyclicParsing(attrOrType)))
       return failure();
 
     return CyclicParseReset(this);
@@ -1377,11 +1377,11 @@ class AsmParser {
   virtual FailureOr<AsmDialectResourceHandle>
   parseResourceHandle(Dialect *dialect) = 0;
 
-  /// Pushes a new attribute or type in the form of a type erased pointer
-  /// into an internal set.
+  /// Pushes a new attribute or type into an internal set.
   /// Returns success if the type or attribute was inserted in the set or
   /// failure if it was already contained.
-  virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;
+  virtual LogicalResult
+  pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) = 0;
 
   /// Removes the element that was last inserted with a successful call to
   /// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 30c0079cda08611..1b88ca240c0eb21 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,8 +570,9 @@ class AsmParserImpl : public BaseT {
     return parser.parseXInDimensionList();
   }
 
-  LogicalResult pushCyclicParsing(const void *opaquePointer) override {
-    return success(parser.getState().cyclicParsingStack.insert(opaquePointer));
+  LogicalResult
+  pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) override {
+    return success(parser.getState().cyclicParsingStack.insert(attrOrType));
   }
 
   void popCyclicParsing() override {
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 3453049a09bee05..0ed385de8cf3b00 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -49,7 +49,7 @@ using namespace mlir::detail;
 ///                    | distinct-attribute
 ///                    | extended-attribute
 ///
-Attribute Parser::parseAttribute(Type type) {
+Attribute Parser::parseAttribute(Type type, StringRef aliasDefName) {
   switch (getToken().getKind()) {
   // Parse an AffineMap or IntegerSet attribute.
   case Token::kw_affine_map: {
@@ -117,7 +117,7 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse an extended attribute, i.e. alias or dialect attribute.
   case Token::hash_identifier:
-    return parseExtendedAttr(type);
+    return parseExtendedAttr(type, aliasDefName);
 
   // Parse floating point and integer attributes.
   case Token::floatliteral:
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 5330fbc9996ff9a..2d78db35433ee5e 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -29,18 +29,37 @@ namespace {
 /// hooking into the main MLIR parsing logic.
 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
 public:
-  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+  CustomDialectAsmParser(StringRef fullSpec, Parser &parser,
+                         StringRef aliasDefName)
       : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
-        fullSpec(fullSpec) {}
+        fullSpec(fullSpec), aliasDefName(aliasDefName) {}
   ~CustomDialectAsmParser() override = default;
 
   /// Returns the full specification of the symbol being parsed. This allows
   /// for using a separate parser if necessary.
   StringRef getFullSymbolSpec() const override { return fullSpec; }
 
+  LogicalResult
+  pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) override {
+    // If this is an alias definition, register the mutable attribute or type.
+    if (!aliasDefName.empty()) {
+      if (auto attr = dyn_cast<Attribute>(attrOrType))
+        parser.getState().symbols.attributeAliasDefinitions[aliasDefName] =
+            attr;
+      else
+        parser.getState().symbols.typeAliasDefinitions[aliasDefName] =
+            cast<Type>(attrOrType);
+    }
+    return AsmParserImpl::pushCyclicParsing(attrOrType);
+  }
+
 private:
   /// The full symbol specification.
   StringRef fullSpec;
+
+  /// If this parser is used to parse an alias definition, the name of the alias
+  /// definition. Empty otherwise.
+  StringRef aliasDefName;
 };
 } // namespace
 
@@ -260,7 +279,7 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
 ///                        | `#` alias-name pretty-dialect-sym-body? (`:` type)?
 ///   attribute-alias    ::= `#` alias-name
 ///
-Attribute Parser::parseExtendedAttr(Type type) {
+Attribute Parser::parseExtendedAttr(Type type, StringRef aliasDefName) {
   MLIRContext *ctx = getContext();
   Attribute attr = parseExtendedSymbol<Attribute>(
       *this, state.asmState, state.symbols.attributeAliasDefinitions,
@@ -282,7 +301,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
           resetToken(symbolData.data());
 
           // Parse the attribute.
-          CustomDialectAsmParser customParser(symbolData, *this);
+          CustomDialectAsmParser customParser(symbolData, *this, aliasDefName);
           Attribute attr = dialect->parseAttribute(customParser, attrType);
           resetToken(curLexerPos);
           return attr;
@@ -311,7 +330,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
 ///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
 ///   type-alias    ::= `!` alias-name
 ///
-Type Parser::parseExtendedType() {
+Type Parser::parseExtendedType(StringRef aliasDefName) {
   MLIRContext *ctx = getContext();
   return parseExtendedSymbol<Type>(
       *this, state.asmState, state.symbols.typeAliasDefinitions,
@@ -327,7 +346,7 @@ Type Parser::parseExtendedType() {
           resetToken(symbolData.data());
 
           // Parse the type.
-          CustomDialectAsmParser customParser(symbolData, *this);
+          CustomDialectAsmParser customParser(symbolData, *this, aliasDefName);
           Type type = dialect->parseType(customParser);
           resetToken(curLexerPos);
           return type;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 264348d32b281a7..4bae575c1c82acb 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2617,10 +2617,10 @@ ParseResult TopLevelOperationParser::parseAliasBlockDef() {
 
   auto parseAttributeAlias = [&](StringRef aliasName,
                                  const UnparsedData &unparsedData) {
-    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+    llvm::SaveAndRestore<SetVector<PointerUnion<Attribute, Type>>> cyclicStack(
         getState().cyclicParsingStack, {});
     auto exit = saveAndResetToken(unparsedData.text.data());
-    Attribute attribute = parseAttribute();
+    Attribute attribute = parseAttribute(Type(), aliasName);
     if (!attribute)
       return attribute;
 
@@ -2634,10 +2634,10 @@ ParseResult TopLevelOperationParser::parseAliasBlockDef() {
 
   auto parseTypeAlias = [&](StringRef aliasName,
                             const UnparsedData &unparsedData) {
-    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+    llvm::SaveAndRestore<SetVector<PointerUnion<Attribute, Type>>> cyclicStack(
         getState().cyclicParsingStack, {});
     auto exit = saveAndResetToken(unparsedData.text.data());
-    Type type = parseType();
+    Type type = parseType(aliasName);
     if (!type)
       return type;
 
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 282c0b13aebd930..0515d9bf956263d 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -202,13 +202,13 @@ class Parser {
   OptionalParseResult parseOptionalType(Type &type);
 
   /// Parse an arbitrary type.
-  Type parseType();
+  Type parseType(StringRef aliasDefName = "");
 
   /// Parse a complex type.
   Type parseComplexType();
 
   /// Parse an extended type.
-  Type parseExtendedType();
+  Type parseExtendedType(StringRef aliasDefName = "");
 
   /// Parse a function type.
   Type parseFunctionType();
@@ -217,7 +217,7 @@ class Parser {
   Type parseMemRefType();
 
   /// Parse a non function type.
-  Type parseNonFunctionType();
+  Type parseNonFunctionType(StringRef aliasDefName = "");
 
   /// Parse a tensor type.
   Type parseTensorType();
@@ -240,7 +240,7 @@ class Parser {
   //===--------------------------------------------------------------------===//
 
   /// Parse an arbitrary attribute with an optional type.
-  Attribute parseAttribute(Type type = {});
+  Attribute parseAttribute(Type type = {}, StringRef aliasDefName = "");
 
   /// Parse an optional attribute with the provided type.
   OptionalParseResult parseOptionalAttribute(Attribute &attribute,
@@ -271,7 +271,7 @@ class Parser {
   Attribute parseDistinctAttr(Type type);
 
   /// Parse an extended attribute.
-  Attribute parseExtendedAttr(Type type);
+  Attribute parseExtendedAttr(Type type, StringRef aliasDefName = "");
 
   /// Parse a float attribute.
   Attribute parseFloatAttr(Type type, bool isNegative);
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 17602f84a90833e..0166116ff0ba3b9 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -82,7 +82,7 @@ struct ParserState {
 
   /// Stack of potentially cyclic mutable attributes or type currently being
   /// parsed.
-  SetVector<const void *> cyclicParsingStack;
+  SetVector<PointerUnion<Attribute, Type>> cyclicParsingStack;
 
   /// An optional pointer to a struct containing high level parser state to be
   /// populated during parsing.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 77e87cf9b0befb7..b24cf7b7021f0bf 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -58,10 +58,10 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
 ///   type ::= function-type
 ///          | non-function-type
 ///
-Type Parser::parseType() {
+Type Parser::parseType(StringRef aliasDefName) {
   if (getToken().is(Token::l_paren))
     return parseFunctionType();
-  return parseNonFunctionType();
+  return parseNonFunctionType(aliasDefName);
 }
 
 /// Parse a function result type.
@@ -273,7 +273,7 @@ Type Parser::parseMemRefType() {
 ///   float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
 ///   none-type ::= `none`
 ///
-Type Parser::parseNonFunctionType() {
+Type Parser::parseNonFunctionType(StringRef aliasDefName) {
   switch (getToken().getKind()) {
   default:
     return (emitWrongTokenError("expected non-function type"), nullptr);
@@ -356,7 +356,7 @@ Type Parser::parseNonFunctionType() {
 
   // extended type
   case Token::exclamation_identifier:
-    return parseExtendedType();
+    return parseExtendedType(aliasDefName);
 
   // Handle completion of a dialect type.
   case Token::code_complete:
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7b0da30541b16a4..6b0a08494ddf7d6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1192,21 +1192,10 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
     alias.print(p.getStream());
     p.getStream() << " = ";
 
-    if (alias.isTypeAlias()) {
-      // TODO: Support nested aliases in mutable types.
-      Type type = Type::getFromOpaquePointer(opaqueSymbol);
-      if (type.hasTrait<TypeTrait::IsMutable>())
-        p.getStream() << type;
-      else
-        p.printTypeImpl(type);
-    } else {
-      // TODO: Support nested aliases in mutable attributes.
-      Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
-      if (attr.hasTrait<AttributeTrait::IsMutable>())
-        p.getStream() << attr;
-      else
-        p.printAttributeImpl(attr);
-    }
+    if (alias.isTypeAlias())
+      p.printTypeImpl(Type::getFromOpaquePointer(opaqueSymbol));
+    else
+      p.printAttributeImpl(Attribute::getFromOpaquePointer(opaqueSymbol));
 
     p.getStream() << newLine;
   }
diff --git a/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir
index 2d63f379c2ee735..ac112f7745e514c 100644
--- a/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir
@@ -106,7 +106,7 @@ func.func @ptr_elem_interface(%arg0: !llvm.ptr<!test.smpla>) {
 !baz = i64
 !qux = !llvm.struct<(!baz)>
 
-!rec = !llvm.struct<"a", (ptr<struct<"a">>)>
+!rec = !llvm.struct<"a", (ptr<!rec>)>
 
 // CHECK: aliases
 llvm.func @aliases() {
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index e10a6fc77e8566d..ae50517eadd4213 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -425,6 +425,11 @@ func.func private @id_struct_redefinition(!spirv.struct<a8, (!spirv.ptr<!spirv.s
 // CHECK: func private @id_struct_recursive(!spirv.struct<a9, (!spirv.ptr<!spirv.struct<b9, (!spirv.ptr<!spirv.struct<a9>, Uniform>)>, Uniform>)>)
 func.func private @id_struct_recursive(!spirv.struct<a9, (!spirv.ptr<!spirv.struct<b9, (!spirv.ptr<!spirv.struct<a9>, Uniform>)>, Uniform>)>) -> ()
 
+!a = !spirv.struct<a9, (!spirv.ptr<!b, Uniform>)>
+!b = !spirv.struct<b9, (!spirv.ptr<!a, Uniform>)>
+// CHECK: func private @id_struct_recursive2(!spirv.struct<a9, (!spirv.ptr<!spirv.struct<b9, (!spirv.ptr<!spirv.struct<a9>, Uniform>)>, Uniform>)>)
+func.func private @id_struct_recursive2(!a) -> ()
+
 // -----
 
 // Equivalent to:
@@ -433,6 +438,11 @@ func.func private @id_struct_recursive(!spirv.struct<a9, (!spirv.ptr<!spirv.stru
 // CHECK: func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.struct<b10, (!spirv.ptr<!spirv.struct<a10>, Uniform>, !spirv.ptr<!spirv.struct<b10>, Uniform>)>, Uniform>)>)
 func.func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.struct<b10, (!spirv.ptr<!spirv.struct<a10>, Uniform>, !spirv.ptr<!spirv.struct<b10>, Uniform>)>, Uniform>)>) -> ()
 
+!a = !spirv.struct<a10, (!spirv.ptr<!b, Uniform>)>
+!b = !spirv.struct<b10, (!spirv.ptr<!a, Uniform>, !spirv.ptr<!b, Uniform>)>
+// CHECK: func private @id_struct_recursive2(!spirv.struct<a10, (!spirv.ptr<!spirv.struct<b10, (!spirv.ptr<!spirv.struct<a10>, Uniform>, !spirv.ptr<!spirv.struct<b10>, Uniform>)>, Uniform>)>)
+func.func private @id_struct_recursive2(!a) -> ()
+
 // -----
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type-and-attr.mlir
similarity index 58%
rename from mlir/test/IR/recursive-type.mlir
rename to mlir/test/IR/recursive-type-and-attr.mlir
index 121ba095573baa7..29a577bef4d0565 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type-and-attr.mlir
@@ -1,8 +1,18 @@
 // RUN: mlir-opt %s -test-recursive-types | FileCheck %s
 
-// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
-// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
-// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
+// CHECK-DAG: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+// CHECK-DAG: ![[$NAME:.*]] = !test.test_rec_alias<name, ![[$NAME]]>
+// CHECK-DAG: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<![[$NAME2]], i32>>
+// CHECK-DAG: #[[$ATTR:.*]] = #test.test_rec_alias<attr, #[[$ATTR]]>
+// CHECK-DAG: #[[$ATTR2:.*]] = #test.test_rec_alias<attr2, [#[[$ATTR2]], 5]>
+
+
+!name = !test.test_rec_alias<name, !name>
+!name2 = !test.test_rec_alias<name2, tuple<!name2, i32>>
+
+#attr = #test.test_rec_alias<attr, #attr>
+#array = [#attr2, 5]
+#attr2 = #test.test_rec_alias<attr2, #array>
 
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
@@ -24,6 +34,14 @@ func.func @roundtrip() {
   // CHECK: () -> ![[$NAME2]]
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
+
+  // Check that we can use these aliases, not just print them.
+  // CHECK: value = #[[$ATTR]]
+  // CHECK-SAME: () -> ![[$NAME]]
+  // CHECK-NEXT: value = #[[$ATTR2]]
+  // CHECK-SAME: () -> ![[$NAME2]]
+  "test.dummy_op_for_roundtrip"() { value = #attr } : () -> !name
+  "test.dummy_op_for_roundtrip"() { value = #attr2 } : () -> !name2
   return
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..26d99218286b7b7 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,22 @@ def Test_IteratorTypeArrayAttr
     : TypedArrayAttrBase<Test_IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+def TestRecursiveAliasAttr
+    : Test_Attr<"TestRecursiveAlias", [NativeAttrTrait<"IsMutable">]> {
+  let mnemonic = "test_rec_alias";
+  let storageClass = "TestRecursiveAttrStorage";
+  let storageNamespace = "test";
+  let genStorageClass = 0;
+
+  let parameters = (ins "llvm::StringRef":$name);
+
+  let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    Attribute getBody() const;
+
+    void setBody(Attribute attribute);
+  }];
+}
 
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..e0e11fe7478f9f1 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -175,6 +175,62 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
   p << (*result ? "true" : "false");
 }
 
+//===----------------------------------------------------------------------===//
+// TestRecursiveAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestRecursiveAliasAttr::getBody() const { return getImpl()->body; }
+
+void TestRecursiveAliasAttr::setBody(Attribute attribute) {
+  (void)Base::mutate(attribute);
+}
+
+StringRef TestRecursiveAliasAttr::getName() const { return getImpl()->name; }
+
+Attribute TestRecursiveAliasAttr::parse(AsmParser &parser, Type type) {
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return nullptr;
+  auto rec = TestRecursiveAliasAttr::get(parser.getContext(), name);
+
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
+      parser.tryStartCyclicParse(rec);
+
+  // If this attribute has already been parsed above in the stack, expect just
+  // the name.
+  if (failed(cyclicParse)) {
+    if (failed(parser.parseGreater()))
+      return nullptr;
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the attribute.
+  if (failed(parser.parseComma()))
+    return nullptr;
+  Attribute subAttr;
+  if (parser.parseAttribute(subAttr))
+    return nullptr;
+  if (!subAttr || failed(parser.parseGreater()))
+    return nullptr;
+
+  rec.setBody(subAttr);
+
+  return rec;
+}
+
+void TestRecursiveAliasAttr::print(AsmPrinter &printer) const {
+
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
+      printer.tryStartCyclicPrint(*this);
+
+  printer << "<" << getName();
+  if (succeeded(cyclicPrint)) {
+    printer << ", ";
+    printer << getBody();
+  }
+  printer << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..d0f24f2738a4c22 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -33,6 +33,37 @@ class TestDialect;
 /// A handle used to reference external elements instances.
 using TestDialectResourceBlobHandle =
     mlir::DialectResourceBlobHandle<TestDialect>;
+
+/// Storage for a simple named recursive attribute, where the attribute is
+/// identified by its name and can "contain" another attribute, including
+/// itself.
+struct TestRecursiveAttrStorage : public ::mlir::AttributeStorage {
+  using KeyTy = ::llvm::StringRef;
+
+  explicit TestRecursiveAttrStorage(::llvm::StringRef key) : name(key) {}
+
+  bool operator==(const KeyTy &other) const { return name == other; }
+
+  static TestRecursiveAttrStorage *
+  construct(::mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<TestRecursiveAttrStorage>())
+        TestRecursiveAttrStorage(allocator.copyInto(key));
+  }
+
+  ::mlir::LogicalResult mutate(::mlir::AttributeStorageAllocator &allocator,
+                               ::mlir::Attribute newBody) {
+    // Cannot set a different body than before.
+    if (body && body != newBody)
+      return ::mlir::failure();
+
+    body = newBody;
+    return ::mlir::success();
+  }
+
+  ::llvm::StringRef name;
+  ::mlir::Attribute body;
+};
+
 } // namespace test
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 950af85007475b9..c34135e7f2792d1 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -169,6 +169,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
   //===------------------------------------------------------------------===//
 
   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
+    if (auto recAliasAttr = dyn_cast<TestRecursiveAliasAttr>(attr)) {
+      os << recAliasAttr.getName();
+      return AliasResult::FinalAlias;
+    }
+
     StringAttr strAttr = dyn_cast<StringAttr>(attr);
     if (!strAttr)
       return AliasResult::NoAlias;
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 2a8bdad8fb25d98..9b685d30302ff49 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -373,7 +373,7 @@ def TestI32 : Test_Type<"TestI32"> {
   let mnemonic = "i32";
 }
 
-def TestRecursiveAlias
+def TestRecursiveAliasType
     : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
   let mnemonic = "test_rec_alias";
   let storageClass = "TestRecursiveTypeStorage";


