[Mlir-commits] [mlir] [mlir] Add concept of alias blocks (PR #65503)

Markus Böck llvmlistbot at llvm.org
Wed Sep 6 10:10:19 PDT 2023


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/65503:

This PR is part of https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236

It implements the concept of "alias blocks": a run of consecutive alias definitions in which any definition may reference any other definition in the block, regardless of definition order. For immutable attributes and types this is purely a convenience, but it is a requirement for supporting aliases in cyclic mutable attributes and types.

The implementation first parses an alias block, which is simply a run of consecutive alias definitions, in a syntax-only mode. This mode checks only that the parsed attribute or type is syntactically valid and does not verify any parsed data. This lets the parser skip over the alias definitions, collecting them and associating each one with its source region.
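To illustrate the collection pass, here is a minimal, self-contained sketch; it is not the patch's code. `UnparsedAlias`, `trim` and `collectAliasBlock` are hypothetical names, and the sketch assumes that every alias definition fits on a single line, whereas the patch finds the end of a value by running a syntax-only parse.

```cpp
#include <cassert>
#include <cstddef>
#include <string_view>
#include <vector>

// Hypothetical record for one alias definition that has been located in the
// source but not parsed yet.
struct UnparsedAlias {
  bool isType;           // true for `!name`, false for `#name`
  std::string_view name; // alias name without the leading sigil
  std::string_view text; // source region holding the aliased value
};

// Strips leading and trailing spaces and tabs.
std::string_view trim(std::string_view s) {
  while (!s.empty() && (s.front() == ' ' || s.front() == '\t'))
    s.remove_prefix(1);
  while (!s.empty() && (s.back() == ' ' || s.back() == '\t'))
    s.remove_suffix(1);
  return s;
}

// Collects the run of alias definitions at the start of `source`.
// Simplifying assumption: one definition per line, so "skipping over" a value
// means reading up to the end of the line.
std::vector<UnparsedAlias> collectAliasBlock(std::string_view source) {
  std::vector<UnparsedAlias> block;
  while (!source.empty()) {
    std::size_t eol = source.find('\n');
    std::string_view line = trim(source.substr(0, eol));
    // Advance past this line, or to the end of the input.
    source.remove_prefix(eol == std::string_view::npos ? source.size()
                                                       : eol + 1);
    if (line.empty())
      continue; // Blank lines do not end the block.
    // The block ends at the first line that is not an alias definition.
    if (line.front() != '#' && line.front() != '!')
      break;
    std::size_t eq = line.find('=');
    if (eq == std::string_view::npos)
      break; // Malformed definition; a real parser would emit an error.
    assert(eq >= 1 && "index 0 holds the sigil");
    block.push_back({line.front() == '!', trim(line.substr(1, eq - 1)),
                     trim(line.substr(eq + 1))});
  }
  return block;
}
```

Storing `std::string_view` slices rather than copies mirrors the idea of associating each definition with its source region: the second pass can later re-read exactly that region.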

In a second pass, the attributes and types are parsed for real. Whenever an unknown alias reference is encountered, it is looked up in the list of yet-to-be-parsed definitions and parsed on demand if required.
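The on-demand resolution can be sketched in isolation as follows. This is a toy model under stated assumptions, not the patch's implementation: `AliasResolver` is a hypothetical class, "parsing" an alias just means textually expanding every `#name` or `!name` reference, alias names are limited to letters, digits and underscores, and the in-progress set that rejects reference cycles is an addition of this sketch (the PR defers cyclic definitions to later work).

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <utility>

// Hypothetical resolver. Keys include the sigil, e.g. "#array" or "!int", so
// attribute and type aliases live in separate namespaces.
class AliasResolver {
public:
  explicit AliasResolver(std::map<std::string, std::string> unparsed)
      : unparsed(std::move(unparsed)) {}

  // Returns the fully expanded value of `name`, parsing it on demand.
  // Returns std::nullopt for an unknown alias or a reference cycle.
  std::optional<std::string> resolve(const std::string &name) {
    if (auto it = resolved.find(name); it != resolved.end())
      return it->second; // Already parsed; avoid parsing twice.
    auto it = unparsed.find(name);
    if (it == unparsed.end())
      return std::nullopt; // Not defined in this block.
    if (!inProgress.insert(name).second)
      return std::nullopt; // Cycle: this sketch simply rejects it.
    std::optional<std::string> value = expand(it->second);
    inProgress.erase(name);
    if (value) {
      assert(resolved.count(name) == 0 && "alias parsed twice");
      resolved.emplace(name, *value);
    }
    return value;
  }

private:
  // Copies `text`, replacing each alias reference with its resolved value.
  std::optional<std::string> expand(const std::string &text) {
    std::string out;
    for (std::size_t i = 0; i < text.size();) {
      if (text[i] != '#' && text[i] != '!') {
        out += text[i++];
        continue;
      }
      // Scan the identifier that follows the sigil.
      std::size_t end = i + 1;
      while (end < text.size() &&
             (std::isalnum(static_cast<unsigned char>(text[end])) ||
              text[end] == '_'))
        ++end;
      std::optional<std::string> value = resolve(text.substr(i, end - i));
      if (!value)
        return std::nullopt;
      out += *value;
      i = end;
    }
    return out;
  }

  std::map<std::string, std::string> unparsed; // yet-to-be-parsed definitions
  std::map<std::string, std::string> resolved; // definitions parsed so far
  std::set<std::string> inProgress;            // definitions being parsed
};
```

With the three definitions `#array = [#integer_attr, !integer_type]`, `!integer_type = i32` and `#integer_attr = 8 : !integer_type`, resolving `#array` first triggers the on-demand resolution of the other two, and the order in which the definitions are visited does not affect the result.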

A later PR will hook up this mechanism to the `tryStartCyclicParse` method added in https://github.com/llvm/llvm-project/commit/b121c266744d030120c59e6256559cbccacd3c6f to register cyclic attributes and types early, breaking the parsing cycles.

From 5a84e4948d2059978c3f343e771286458c3f6965 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Tue, 5 Sep 2023 22:33:57 +0200
Subject: [PATCH] [mlir] Add concept of alias blocks

This PR is part of https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236

It implements the concept of "alias blocks", a block of alias definitions which may alias any other alias definitions within the block, regardless of definition order. This is purely a convenience for immutable attributes and types, but is a requirement for supporting aliasing definitions in cyclic mutable attributes and types.

The implementation works by first parsing an alias-block, which is simply subsequent alias definitions, in a syntax-only mode. This syntax-only mode only checks for syntactic validity of the parsed attribute or type but does not verify any parsed data. This allows us to essentially skip over alias definitions for the purpose of first collecting them and associating every alias definition with its source region.

In a second pass, we can start parsing the attributes and types while at the same time attempting to resolve any unknown alias references with our list of yet-to-be-parsed attributes and types, parsing them on demand if required.

A later PR will hook up this mechanism to the `tryStartCyclicParse` method added in https://github.com/llvm/llvm-project/commit/b121c266744d030120c59e6256559cbccacd3c6f to early register cyclic attributes and types, breaking the parsing cycles.
---
 mlir/docs/LangRef.md                       |  31 ++-
 mlir/lib/AsmParser/AttributeParser.cpp     |  48 ++++-
 mlir/lib/AsmParser/DialectSymbolParser.cpp |  41 +++-
 mlir/lib/AsmParser/LocationParser.cpp      |  17 +-
 mlir/lib/AsmParser/Parser.cpp              | 228 +++++++++++++++------
 mlir/lib/AsmParser/Parser.h                |  21 +-
 mlir/lib/AsmParser/ParserState.h           |  20 ++
 mlir/lib/AsmParser/TypeParser.cpp          |  25 ++-
 mlir/test/IR/alias-def-groups.mlir         |  25 +++
 9 files changed, 370 insertions(+), 86 deletions(-)
 create mode 100644 mlir/test/IR/alias-def-groups.mlir

diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 0cfe845638c3cfb..a00f9f76c4253b2 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line.
 
 ```
 // Top level production
-toplevel := (operation | attribute-alias-def | type-alias-def)*
+toplevel := (operation | alias-block-def)*
+alias-block-def := (attribute-alias-def | type-alias-def)+
 ```
 
 The production `toplevel` is the top level production that is parsed by any parsing
-consuming the MLIR syntax. [Operations](#operations),
-[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
+consuming the MLIR syntax. [Operations](#operations) and 
+[Alias Blocks](#alias-block-definitions) consisting of 
+[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases)
 can be declared on the toplevel.
 
 ### Identifiers and keywords
@@ -880,3 +882,26 @@ version using readAttribute and readType methods.
 There is no restriction on what kind of information a dialect is allowed to
 encode to model its versioning. Currently, versioning is supported only for
 bytecode formats.
+
+## Alias Block Definitions
+
+An alias block is a list of consecutive attribute or type alias definitions that
+are conceptually parsed as one unit.
+This allows any alias definition within the block to reference any other alias
+definition within the block, regardless of whether it is defined lexically
+earlier or later in the block.
+
+```mlir
+// Alias block consisting of #array, !integer_type and #integer_attr.
+#array = [#integer_attr, !integer_type]
+!integer_type = i32
+#integer_attr = 8 : !integer_type
+
+// Illegal. !other_type is not part of this alias block and is defined later
+// in the file.
+!tuple = tuple<i32, !other_type>
+
+func.func @foo() { ... }
+
+!other_type = f32
+```
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 3437ac9addc5ff6..3453049a09bee05 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) {
         parseLocationInstance(locAttr) ||
         parseToken(Token::r_paren, "expected ')' in inline location"))
       return Attribute();
+
+    if (syntaxOnly())
+      return state.syntaxOnlyAttr;
+
     return locAttr;
   }
 
@@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return FloatAttr::get(floatType, *result);
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   if (!isa<IntegerType, IndexType>(type))
     return emitError(loc, "integer literal not valid for specified type"),
            nullptr;
@@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
-  return literalParser.getAttr(loc, type);
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+  return literalParser.getAttr(loc, cast<ShapedType>(type));
 }
 
 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
@@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
       return nullptr;
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   ShapedType shapedType = dyn_cast<ShapedType>(attrType);
   if (!shapedType) {
     emitError(typeLoc, "`dense_resource` expected a shaped type");
@@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
 ///   elements-literal-type ::= vector-type | ranked-tensor-type
 ///
 /// This method also checks the type has static shape.
-ShapedType Parser::parseElementsLiteralType(Type type) {
+Type Parser::parseElementsLiteralType(Type type) {
   // If the user didn't provide a type, parse the colon type for the literal.
   if (!type) {
     if (parseToken(Token::colon, "expected ':'"))
@@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
       return nullptr;
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   auto sType = dyn_cast<ShapedType>(type);
   if (!sType) {
     emitError("elements literal must be a shaped type");
@@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // of the type.
   Type indiceEltType = builder.getIntegerType(64);
   if (consumeIf(Token::greater)) {
-    ShapedType type = parseElementsLiteralType(attrType);
+    Type type = parseElementsLiteralType(attrType);
     if (!type)
       return nullptr;
 
+    if (syntaxOnly())
+      return state.syntaxOnlyAttr;
+
     // Construct the sparse elements attr using zero element indice/value
     // attributes.
+    ShapedType shapedType = cast<ShapedType>(type);
     ShapedType indicesType =
-        RankedTensorType::get({0, type.getRank()}, indiceEltType);
-    ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
+        RankedTensorType::get({0, shapedType.getRank()}, indiceEltType);
+    ShapedType valuesType =
+        RankedTensorType::get({0}, shapedType.getElementType());
     return getChecked<SparseElementsAttr>(
-        loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+        loc, shapedType,
+        DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
   }
 
@@ -1114,6 +1135,11 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   if (!type)
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
+  ShapedType shapedType = cast<ShapedType>(type);
+
   // If the indices are a splat, i.e. the literal parser parsed an element and
   // not a list, we set the shape explicitly. The indices are represented by a
   // 2-dimensional shape where the second dimension is the rank of the type.
@@ -1121,7 +1147,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // indice and thus one for the first dimension.
   ShapedType indicesType;
   if (indiceParser.getShape().empty()) {
-    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+    indicesType =
+        RankedTensorType::get({1, shapedType.getRank()}, indiceEltType);
   } else {
     // Otherwise, set the shape to the one parsed by the literal parser.
     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
@@ -1131,7 +1158,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // If the values are a splat, set the shape explicitly based on the number of
   // indices. The number of indices is encoded in the first dimension of the
   // indice shape type.
-  auto valuesEltType = type.getElementType();
+  auto valuesEltType = shapedType.getElementType();
   ShapedType valuesType =
       valuesParser.getShape().empty()
           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
@@ -1139,7 +1166,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   auto values = valuesParser.getAttr(valuesLoc, valuesType);
 
   // Build the sparse elements attribute by the indices and values.
-  return getChecked<SparseElementsAttr>(loc, type, indices, values);
+  return getChecked<SparseElementsAttr>(loc, shapedType, indices, values);
 }
 
 Attribute Parser::parseStridedLayoutAttr() {
@@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) {
       return {};
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyAttr;
+
   // Add the distinct attribute to the parser state, if it has not been parsed
   // before. Otherwise, check if the parsed reference attribute matches the one
   // found in the parser state.
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 2b1b114b90e86af..5330fbc9996ff9a 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
@@ -156,9 +157,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
 }
 
 /// Parse an extended dialect symbol.
-template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
+template <typename Symbol, typename SymbolAliasMap, typename ParseAliasFn,
+          typename CreateFn>
 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
                                   SymbolAliasMap &aliases,
+                                  ParseAliasFn &parseAliasFn,
                                   CreateFn &&createSymbol) {
   Token tok = p.getToken();
 
@@ -185,12 +188,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
   // If there is no '<' token following this, and if the typename contains no
   // dot, then we are parsing a symbol alias.
   if (!hasTrailingData && !isPrettyName) {
+
+    // Don't check the validity of alias reference in syntax-only mode.
+    if (p.syntaxOnly()) {
+      if constexpr (std::is_same_v<Symbol, Type>)
+        return p.getState().syntaxOnlyType;
+      else
+        return p.getState().syntaxOnlyAttr;
+    }
+
     // Check for an alias for this type.
     auto aliasIt = aliases.find(identifier);
-    if (aliasIt == aliases.end())
-      return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
-                                    "'"),
-              nullptr);
+    if (aliasIt == aliases.end()) {
+      FailureOr<Symbol> symbol = failure();
+      // Try the parse alias function if set.
+      if (parseAliasFn)
+        symbol = parseAliasFn(identifier);
+
+      if (failed(symbol)) {
+        p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
+        return nullptr;
+      }
+      if (!*symbol)
+        return nullptr;
+
+      aliasIt = aliases.insert({identifier, *symbol}).first;
+    }
     if (asmState) {
       if constexpr (std::is_same_v<Symbol, Type>)
         asmState->addTypeAliasUses(identifier, range);
@@ -241,12 +264,16 @@ Attribute Parser::parseExtendedAttr(Type type) {
   MLIRContext *ctx = getContext();
   Attribute attr = parseExtendedSymbol<Attribute>(
       *this, state.asmState, state.symbols.attributeAliasDefinitions,
+      state.symbols.parseUnknownAttributeAlias,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
         // Parse an optional trailing colon type.
         Type attrType = type;
         if (consumeIf(Token::colon) && !(attrType = parseType()))
           return Attribute();
 
+        if (syntaxOnly())
+          return state.syntaxOnlyAttr;
+
         // If we found a registered dialect, then ask it to parse the attribute.
         if (Dialect *dialect =
                 builder.getContext()->getOrLoadDialect(dialectName)) {
@@ -288,7 +315,11 @@ Type Parser::parseExtendedType() {
   MLIRContext *ctx = getContext();
   return parseExtendedSymbol<Type>(
       *this, state.asmState, state.symbols.typeAliasDefinitions,
+      state.symbols.parseUnknownTypeAlias,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
+        if (syntaxOnly())
+          return state.syntaxOnlyType;
+
         // If we found a registered dialect, then ask it to parse the type.
         if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
           // Temporarily reset the lexer to let the dialect parse the type.
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index 61b20179800c6cc..8139f188c32a740 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
   if (parseToken(Token::r_paren, "expected ')' in callsite location"))
     return failure();
 
+  if (syntaxOnly())
+    return success();
+
   // Return the callsite location.
   loc = CallSiteLoc::get(calleeLoc, callerLoc);
   return success();
@@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
     LocationAttr newLoc;
     if (parseLocationInstance(newLoc))
       return failure();
+    if (syntaxOnly())
+      return success();
+
     locations.push_back(newLoc);
     return success();
   };
@@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
     if (parseLocationInstance(childLoc))
       return failure();
 
-    loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
-
     // Parse the closing ')'.
     if (parseToken(Token::r_paren,
                    "expected ')' after child location of NameLoc"))
       return failure();
+
+    if (syntaxOnly())
+      return success();
+
+    loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
   } else {
     loc = NameLoc::get(StringAttr::get(ctx, str));
   }
@@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
     Attribute locAttr = parseExtendedAttr(Type());
     if (!locAttr)
       return failure();
+
+    if (syntaxOnly())
+      return success();
+
     if (!(loc = dyn_cast<LocationAttr>(locAttr)))
       return emitError("expected location attribute, but got") << locAttr;
     return success();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 02bf9a418063991..738485fb4537210 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
 #include <algorithm>
 #include <optional>
@@ -2403,17 +2404,12 @@ class TopLevelOperationParser : public Parser {
   ParseResult parse(Block *topLevelBlock, Location parserLoc);
 
 private:
-  /// Parse an attribute alias declaration.
+  /// Parse an alias block definition.
   ///
   ///   attribute-alias-def ::= '#' alias-name `=` attribute-value
-  ///
-  ParseResult parseAttributeAliasDef();
-
-  /// Parse a type alias declaration.
-  ///
   ///   type-alias-def ::= '!' alias-name `=` type
-  ///
-  ParseResult parseTypeAliasDef();
+  ///   alias-block-def ::= (type-alias-def | attribute-alias-def)+
+  ParseResult parseAliasBlockDef();
 
   /// Parse a top-level file metadata dictionary.
   ///
@@ -2514,69 +2510,184 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
   Token value;
   Parser &p;
 };
+
+/// Convenient subclass of `ParserState` which configures the parser for
+/// syntax-only parsing. This only copies config and other required state for
+/// parsing but does not copy side-effecting state such as the code completion
+/// context.
+struct SyntaxParserState : ParserState {
+  explicit SyntaxParserState(ParserState &state)
+      : ParserState(state.lex.getSourceMgr(), state.config, state.symbols,
+                    /*asmState=*/nullptr,
+                    /*codeCompleteContext=*/nullptr) {
+    syntaxOnly = true;
+  }
+};
+
 } // namespace
 
-ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
-  assert(getToken().is(Token::hash_identifier));
-  StringRef aliasName = getTokenSpelling().drop_front();
+ParseResult TopLevelOperationParser::parseAliasBlockDef() {
 
-  // Check for redefinitions.
-  if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0)
-    return emitError("redefinition of attribute alias id '" + aliasName + "'");
+  struct UnparsedData {
+    SMRange location;
+    StringRef text;
+  };
 
-  // Make sure this isn't invading the dialect attribute namespace.
-  if (aliasName.contains('.'))
-    return emitError("attribute names with a '.' are reserved for "
-                     "dialect-defined names");
+  // Use a map vector as StringMap has non-deterministic iteration order.
+  using StringMapVector =
+      llvm::MapVector<StringRef, UnparsedData, llvm::StringMap<unsigned>>;
 
-  SMRange location = getToken().getLocRange();
-  consumeToken(Token::hash_identifier);
+  StringMapVector unparsedAttributeAliases;
+  StringMapVector unparsedTypeAliases;
 
-  // Parse the '='.
-  if (parseToken(Token::equal, "expected '=' in attribute alias definition"))
-    return failure();
+  // Returns true if this alias has already been defined, either in this block
+  // or a previous one.
+  auto isRedefinition = [&](bool isType, StringRef aliasName) {
+    if (isType)
+      return state.symbols.typeAliasDefinitions.contains(aliasName) ||
+             unparsedTypeAliases.contains(aliasName);
 
-  // Parse the attribute value.
-  Attribute attr = parseAttribute();
-  if (!attr)
-    return failure();
+    return state.symbols.attributeAliasDefinitions.contains(aliasName) ||
+           unparsedAttributeAliases.contains(aliasName);
+  };
 
-  // Register this alias with the parser state.
-  if (state.asmState)
-    state.asmState->addAttrAliasDefinition(aliasName, location, attr);
-  state.symbols.attributeAliasDefinitions[aliasName] = attr;
-  return success();
-}
+  // Collect all attribute or type alias definitions in unparsed form first.
+  while (
+      getToken().isAny(Token::exclamation_identifier, Token::hash_identifier)) {
+    StringRef aliasName = getTokenSpelling().drop_front();
+
+    bool isType = getToken().is(mlir::Token::exclamation_identifier);
+    StringRef kind = isType ? "type" : "attribute";
+
+    // Check for redefinitions.
+    if (isRedefinition(isType, aliasName))
+      return emitError("redefinition of ")
+             << kind << " alias id '" << aliasName << "'";
+
+    // Make sure this isn't invading the dialect namespace.
+    if (aliasName.contains('.'))
+      return emitError(kind) << " names with a '.' are reserved for "
+                                "dialect-defined names";
+
+    SMRange location = getToken().getLocRange();
+    consumeToken();
+
+    // Parse the '='.
+    if (parseToken(Token::equal,
+                   "expected '=' in " + kind + " alias definition"))
+      return failure();
 
-ParseResult TopLevelOperationParser::parseTypeAliasDef() {
-  assert(getToken().is(Token::exclamation_identifier));
-  StringRef aliasName = getTokenSpelling().drop_front();
+    SyntaxParserState skippingParserState(state);
+    Parser syntaxOnlyParser(skippingParserState);
+    const char *start = getToken().getLoc().getPointer();
+    syntaxOnlyParser.resetToken(start);
 
-  // Check for redefinitions.
-  if (state.symbols.typeAliasDefinitions.count(aliasName) > 0)
-    return emitError("redefinition of type alias id '" + aliasName + "'");
+    // Parse just the syntax of the value, moving the lexer past the definition.
+    if (isType ? !syntaxOnlyParser.parseType()
+               : !syntaxOnlyParser.parseAttribute())
+      return failure();
+
+    // Get the location from the lexer's new position.
+    const char *end = syntaxOnlyParser.getToken().getLoc().getPointer();
+    size_t length = end - start;
 
-  // Make sure this isn't invading the dialect type namespace.
-  if (aliasName.contains('.'))
-    return emitError("type names with a '.' are reserved for "
-                     "dialect-defined names");
+    StringMapVector &unparsedMap =
+        isType ? unparsedTypeAliases : unparsedAttributeAliases;
 
-  SMRange location = getToken().getLocRange();
-  consumeToken(Token::exclamation_identifier);
+    unparsedMap[aliasName] =
+        UnparsedData{location, StringRef(start, length).rtrim()};
+
+    // Move the top-level parser past the alias definition.
+    resetToken(end);
+  }
+
+  auto parseAttributeAlias = [&](StringRef aliasName,
+                                 const UnparsedData &unparsedData) {
+    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+        getState().cyclicParsingStack, {});
+    auto exit = saveAndResetToken(unparsedData.text.data());
+    Attribute attribute = parseAttribute();
+    if (!attribute)
+      return attribute;
+
+    // Register this alias with the parser state.
+    if (state.asmState)
+      state.asmState->addAttrAliasDefinition(aliasName, unparsedData.location,
+                                             attribute);
 
-  // Parse the '='.
-  if (parseToken(Token::equal, "expected '=' in type alias definition"))
+    return attribute;
+  };
+
+  auto parseTypeAlias = [&](StringRef aliasName,
+                            const UnparsedData &unparsedData) {
+    llvm::SaveAndRestore<SetVector<const void *>> cyclicStack(
+        getState().cyclicParsingStack, {});
+    auto exit = saveAndResetToken(unparsedData.text.data());
+    Type type = parseType();
+    if (!type)
+      return type;
+
+    // Register this alias with the parser state.
+    if (state.asmState)
+      state.asmState->addTypeAliasDefinition(aliasName, unparsedData.location,
+                                             type);
+
+    return type;
+  };
+
+  // Set the callbacks for the lazy parsing of alias definitions in the parser.
+  state.symbols.parseUnknownAttributeAlias =
+      [&](StringRef aliasName) -> FailureOr<Attribute> {
+    auto *iter = unparsedAttributeAliases.find(aliasName);
+    if (iter == unparsedAttributeAliases.end())
+      return failure();
+
+    return parseAttributeAlias(aliasName, iter->second);
+  };
+  state.symbols.parseUnknownTypeAlias =
+      [&](StringRef aliasName) -> FailureOr<Type> {
+    auto *iter = unparsedTypeAliases.find(aliasName);
+    if (iter == unparsedTypeAliases.end())
+      return failure();
+
+    return parseTypeAlias(aliasName, iter->second);
+  };
+
+  // Reset them to nullptr at the end. Keeping them around would lead to the
+  // access of local variables captured in this scope after we've returned from
+  // this function.
+  auto exit = llvm::make_scope_exit([&] {
+    state.symbols.parseUnknownTypeAlias = nullptr;
+    state.symbols.parseUnknownAttributeAlias = nullptr;
+  });
+
+  // Now go through all the unparsed definitions in the block and parse them.
+  // The order here is not significant for correctness, but should be
+  // deterministic. The order can also have an impact on the maximum stack usage
+  // during parsing. This can be improved in the future.
+  auto parse = [](auto &unparsed, auto &definitions, auto &parseFn) {
+    for (auto &&[aliasName, unparsedData] : unparsed) {
+      // Avoid parsing twice.
+      if (definitions.contains(aliasName))
+        continue;
+
+      auto symbol = parseFn(aliasName, unparsedData);
+      if (!symbol)
+        return failure();
+      definitions[aliasName] = symbol;
+    }
+    return success();
+  };
+
+  if (failed(parse(unparsedAttributeAliases,
+                   state.symbols.attributeAliasDefinitions,
+                   parseAttributeAlias)))
     return failure();
 
-  // Parse the type.
-  Type aliasedType = parseType();
-  if (!aliasedType)
+  if (failed(parse(unparsedTypeAliases, state.symbols.typeAliasDefinitions,
+                   parseTypeAlias)))
     return failure();
 
-  // Register this alias with the parser state.
-  if (state.asmState)
-    state.asmState->addTypeAliasDefinition(aliasName, location, aliasedType);
-  state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
   return success();
 }
 
@@ -2715,15 +2826,10 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
     case Token::error:
       return failure();
 
-    // Parse an attribute alias.
+    // Parse an alias def block.
     case Token::hash_identifier:
-      if (parseAttributeAliasDef())
-        return failure();
-      break;
-
-    // Parse a type alias.
     case Token::exclamation_identifier:
-      if (parseTypeAliasDef())
+      if (parseAliasBlockDef())
         return failure();
       break;
 
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 01c55f97a08c2ce..282c0b13aebd930 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -12,6 +12,7 @@
 #include "ParserState.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include <optional>
 
 namespace mlir {
@@ -132,6 +133,22 @@ class Parser {
     state.curToken = state.lex.lexToken();
   }
 
+  /// Temporarily resets the parser to the given lexer position. The previous
+  /// lexer position is saved and restored on destruction of the returned
+  /// object.
+  [[nodiscard]] auto saveAndResetToken(const char *tokPos) {
+    const char *previous = getToken().getLoc().getPointer();
+    resetToken(tokPos);
+    return llvm::make_scope_exit([this, previous] { resetToken(previous); });
+  }
+
+  /// Returns true if the parser is in syntax-only mode. In this mode, the
+  /// parser only checks the syntactic validity of the parsed elements but does
+  /// not verify the correctness of the parsed data. Syntax-only mode is
+  /// currently only supported for attribute and type parsing and skips parsing
+  /// dialect attributes and types entirely.
+  bool syntaxOnly() { return state.syntaxOnly; }
+
   /// Consume the specified token if present and return success.  On failure,
   /// output a diagnostic and return failure.
   ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
@@ -209,7 +226,7 @@ class Parser {
   Type parseTupleType();
 
   /// Parse a vector type.
-  VectorType parseVectorType();
+  Type parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                        SmallVectorImpl<bool> &scalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
@@ -265,7 +282,7 @@ class Parser {
 
   /// Parse a dense elements attribute.
   Attribute parseDenseElementsAttr(Type attrType);
-  ShapedType parseElementsLiteralType(Type type);
+  Type parseElementsLiteralType(Type type);
 
   /// Parse a dense resource elements attribute.
   Attribute parseDenseResourceElementsAttr(Type attrType);
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 1428ea3a82cee9f..17602f84a90833e 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -32,6 +32,15 @@ struct SymbolState {
   /// A map from type alias identifier to Type.
   llvm::StringMap<Type> typeAliasDefinitions;
 
+  /// Parser functions set during the parsing of alias-block-defs to parse an
+  /// unknown attribute or type alias. The parameter is the name of the alias.
+  /// The function should return failure if no such alias could be found.
+  /// If any errors occurred during parsing, a null attribute or type should
+  /// be returned.
+  llvm::unique_function<FailureOr<Attribute>(StringRef)>
+      parseUnknownAttributeAlias;
+  llvm::unique_function<FailureOr<Type>(StringRef)> parseUnknownTypeAlias;
+
   /// A map of dialect resource keys to the resolved resource name and handle
   /// to use during parsing.
   DenseMap<const OpAsmDialectInterface *,
@@ -88,6 +97,17 @@ struct ParserState {
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  /// Controls whether the parser is in syntax-only mode.
+  bool syntaxOnly = false;
+
+  /// Attribute and type returned by `parseType`, `parseAttribute` and the more
+  /// specific parsing functions to signal syntactic correctness if an attribute
+  /// or type cannot be created without verifying the parsed data as well.
+  /// Callers of such functions should only check for null or not null return
+  /// values for error signaling.
+  Type syntaxOnlyType = NoneType::get(config.getContext());
+  Attribute syntaxOnlyAttr = UnitAttr::get(config.getContext());
 };
 
 } // namespace detail
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 306e850af27bc58..77e87cf9b0befb7 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -130,6 +130,10 @@ Type Parser::parseComplexType() {
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in complex type"))
     return nullptr;
+
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
     return emitError(elementTypeLoc, "invalid element type for complex"),
            nullptr;
@@ -150,6 +154,9 @@ Type Parser::parseFunctionType() {
       parseFunctionResultTypes(results))
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   return builder.getFunctionType(arguments, results);
 }
 
@@ -195,9 +202,10 @@ Type Parser::parseMemRefType() {
   if (!elementType)
     return nullptr;
 
-  // Check that memref is formed from allowed types.
-  if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError(typeLoc, "invalid memref element type"), nullptr;
+  if (!syntaxOnly()) { // Check that memref is formed from allowed types.
+    if (!BaseMemRefType::isValidElementType(elementType))
+      return emitError(typeLoc, "invalid memref element type"), nullptr;
+  }
 
   MemRefLayoutAttrInterface layout;
   Attribute memorySpace;
@@ -208,6 +216,9 @@ Type Parser::parseMemRefType() {
     if (!attr)
       return failure();
 
+    if (syntaxOnly())
+      return success();
+
     if (isa<MemRefLayoutAttrInterface>(attr)) {
       layout = cast<MemRefLayoutAttrInterface>(attr);
     } else if (memorySpace) {
@@ -235,6 +246,9 @@ Type Parser::parseMemRefType() {
     }
   }
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (isUnranked)
     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
 
@@ -437,7 +451,7 @@ Type Parser::parseTupleType() {
 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
 ///
-VectorType Parser::parseVectorType() {
+Type Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
@@ -458,6 +472,9 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
+  if (syntaxOnly())
+    return state.syntaxOnlyType;
+
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
diff --git a/mlir/test/IR/alias-def-groups.mlir b/mlir/test/IR/alias-def-groups.mlir
new file mode 100644
index 000000000000000..71d09371fae67a3
--- /dev/null
+++ b/mlir/test/IR/alias-def-groups.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -split-input-file %s | FileCheck %s
+
+#array = [#integer_attr, !integer_type]
+!integer_type = i32
+#integer_attr = 8 : !integer_type
+
+// CHECK-LABEL: func @foo()
+func.func @foo() {
+  // CHECK-NEXT: value = [8 : i32, i32]
+  "foo.attr"() { value = #array} : () -> ()
+}
+
+// -----
+
+// Check that aliases may only reference later-defined aliases of the same block.
+
+// expected-error@below {{undefined symbol alias id 'integer_attr'}}
+#array = [!integer_type, #integer_attr]
+!integer_type = i32
+
+func.func @foo() {
+  %0 = "foo.attr"() { value = #array}
+}
+
+#integer_attr = 8 : !integer_type


