[Mlir-commits] [mlir] fe4f512 - [mlir:LSP] Add support for code completing attributes and types

River Riddle llvmlistbot at llvm.org
Fri Jul 8 16:56:29 PDT 2022


Author: River Riddle
Date: 2022-07-08T16:24:55-07:00
New Revision: fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9

URL: https://github.com/llvm/llvm-project/commit/fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9
DIFF: https://github.com/llvm/llvm-project/commit/fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9.diff

LOG: [mlir:LSP] Add support for code completing attributes and types

This required changing how attributes and types are parsed. A new
`KeywordSwitch` class was added to AsmParser that provides a StringSwitch-like
API for parsing a keyword against a set of potential matches. It is intended
both to provide a cleaner API and to allow code completion to be injected
while parsing keywords. This in turn required changing the API of
`generated(Attr|Type)Parser` so that it parses the mnemonic keyword itself,
instead of requiring the caller to do it. Most upstream dialects use the
autogenerated handling and did not require a direct update.

Differential Revision: https://reviews.llvm.org/D129267

Added: 
    

Modified: 
    flang/lib/Optimizer/Dialect/FIRType.cpp
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/Parser/CodeComplete.h
    mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
    mlir/lib/Parser/AsmParserImpl.h
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/Parser/Lexer.h
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/lib/Parser/TypeParser.cpp
    mlir/lib/Tools/lsp-server-support/Protocol.h
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-lsp-server/completion.test
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/default-type-attr-print-parser.td
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 0a8a2eb0385b7..77075395576da 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -116,10 +116,8 @@ RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
 mlir::Type fir::parseFirType(FIROpsDialect *dialect,
                              mlir::DialectAsmParser &parser) {
   mlir::StringRef typeTag;
-  if (parser.parseKeyword(&typeTag))
-    return {};
   mlir::Type genType;
-  auto parseResult = generatedTypeParser(parser, typeTag, genType);
+  auto parseResult = generatedTypeParser(parser, &typeTag, genType);
   if (parseResult.hasValue())
     return genType;
   parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;

diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 36c134ee9435b..e1ad13952fe0b 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -473,10 +473,10 @@ one for printing. These static functions placed alongside the class definitions
 and have the following function signatures:
 
 ```c++
-static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef mnemonic, Type attrType, Attribute &result);
+static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef *mnemonic, Type attrType, Attribute &result);
 static LogicalResult generatedAttributePrinter(Attribute attr, DialectAsmPrinter& printer);
 
-static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef mnemonic, Type &result);
+static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef *mnemonic, Type &result);
 static LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
 ```
 

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 33967b0e42943..cb240a40e0c79 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -571,43 +571,6 @@ class AsmParser {
   /// Parse a quoted string token if present.
   virtual ParseResult parseOptionalString(std::string *string) = 0;
 
-  /// Parse a given keyword.
-  ParseResult parseKeyword(StringRef keyword) {
-    return parseKeyword(keyword, "");
-  }
-  virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
-
-  /// Parse a keyword into 'keyword'.
-  ParseResult parseKeyword(StringRef *keyword) {
-    auto loc = getCurrentLocation();
-    if (parseOptionalKeyword(keyword))
-      return emitError(loc, "expected valid keyword");
-    return success();
-  }
-
-  /// Parse the given keyword if present.
-  virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
-
-  /// Parse a keyword, if present, into 'keyword'.
-  virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
-
-  /// Parse a keyword, if present, and if one of the 'allowedValues',
-  /// into 'keyword'
-  virtual ParseResult
-  parseOptionalKeyword(StringRef *keyword,
-                       ArrayRef<StringRef> allowedValues) = 0;
-
-  /// Parse a keyword or a quoted string.
-  ParseResult parseKeywordOrString(std::string *result) {
-    if (failed(parseOptionalKeywordOrString(result)))
-      return emitError(getCurrentLocation())
-             << "expected valid keyword or string";
-    return success();
-  }
-
-  /// Parse an optional keyword or string.
-  virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
-
   /// Parse a `(` token.
   virtual ParseResult parseLParen() = 0;
 
@@ -712,6 +675,115 @@ class AsmParser {
     return parseCommaSeparatedList(Delimiter::None, parseElementFn);
   }
 
+  //===--------------------------------------------------------------------===//
+  // Keyword Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// This class represents a StringSwitch like class that is useful for parsing
+  /// expected keywords. On construction, it invokes `parseKeyword` and
+  /// processes each of the provided cases statements until a match is hit. The
+  /// provided `ResultT` must be assignable from `failure()`.
+  template <typename ResultT = ParseResult>
+  class KeywordSwitch {
+  public:
+    KeywordSwitch(AsmParser &parser)
+        : parser(parser), loc(parser.getCurrentLocation()) {
+      if (failed(parser.parseKeywordOrCompletion(&keyword)))
+        result = failure();
+    }
+
+    /// Case that uses the provided value when true.
+    KeywordSwitch &Case(StringLiteral str, ResultT value) {
+      return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
+    }
+    KeywordSwitch &Default(ResultT value) {
+      return Default([&](StringRef, SMLoc) { return std::move(value); });
+    }
+    /// Case that invokes the provided functor when true. The parameters passed
+    /// to the functor are the keyword, and the location of the keyword (in case
+    /// any errors need to be emitted).
+    template <typename FnT>
+    std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
+    Case(StringLiteral str, FnT &&fn) {
+      if (result)
+        return *this;
+
+      // If the word was empty, record this as a completion.
+      if (keyword.empty())
+        parser.codeCompleteExpectedTokens(str);
+      else if (keyword == str)
+        result.emplace(std::move(fn(keyword, loc)));
+      return *this;
+    }
+    template <typename FnT>
+    std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
+    Default(FnT &&fn) {
+      if (!result)
+        result.emplace(fn(keyword, loc));
+      return *this;
+    }
+
+    /// Returns true if this switch has a value yet.
+    bool hasValue() const { return result.hasValue(); }
+
+    /// Return the result of the switch.
+    LLVM_NODISCARD operator ResultT() {
+      if (!result)
+        return parser.emitError(loc, "unexpected keyword: ") << keyword;
+      return std::move(*result);
+    }
+
+  private:
+    /// The parser used to construct this switch.
+    AsmParser &parser;
+
+    /// The location of the keyword, used to emit errors as necessary.
+    SMLoc loc;
+
+    /// The parsed keyword itself.
+    StringRef keyword;
+
+    /// The result of the switch statement or none if currently unknown.
+    Optional<ResultT> result;
+  };
+
+  /// Parse a given keyword.
+  ParseResult parseKeyword(StringRef keyword) {
+    return parseKeyword(keyword, "");
+  }
+  virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
+
+  /// Parse a keyword into 'keyword'.
+  ParseResult parseKeyword(StringRef *keyword) {
+    auto loc = getCurrentLocation();
+    if (parseOptionalKeyword(keyword))
+      return emitError(loc, "expected valid keyword");
+    return success();
+  }
+
+  /// Parse the given keyword if present.
+  virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
+
+  /// Parse a keyword, if present, into 'keyword'.
+  virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
+
+  /// Parse a keyword, if present, and if one of the 'allowedValues',
+  /// into 'keyword'
+  virtual ParseResult
+  parseOptionalKeyword(StringRef *keyword,
+                       ArrayRef<StringRef> allowedValues) = 0;
+
+  /// Parse a keyword or a quoted string.
+  ParseResult parseKeywordOrString(std::string *result) {
+    if (failed(parseOptionalKeywordOrString(result)))
+      return emitError(getCurrentLocation())
+             << "expected valid keyword or string";
+    return success();
+  }
+
+  /// Parse an optional keyword or string.
+  virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
+
   //===--------------------------------------------------------------------===//
   // Attribute/Type Parsing
   //===--------------------------------------------------------------------===//
@@ -1124,6 +1196,17 @@ class AsmParser {
   virtual FailureOr<AsmDialectResourceHandle>
   parseResourceHandle(Dialect *dialect) = 0;
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a keyword, or an empty string if the current location signals a code
+  /// completion.
+  virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
+
+  /// Signal the code completion of a set of expected tokens.
+  virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
+
 private:
   AsmParser(const AsmParser &) = delete;
   void operator=(const AsmParser &) = delete;

diff  --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h
index 7dd7b6151aa69..1dcb2745d4bb5 100644
--- a/mlir/include/mlir/Parser/CodeComplete.h
+++ b/mlir/include/mlir/Parser/CodeComplete.h
@@ -10,9 +10,13 @@
 #define MLIR_PARSER_CODECOMPLETE_H
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
 namespace mlir {
+class Attribute;
+class Type;
+
 /// This class provides an abstract interface into the parser for hooking in
 /// code completion events. This class is only really useful for providing
 /// language tooling for MLIR, general clients should not need to use this
@@ -28,8 +32,9 @@ class AsmParserCodeCompleteContext {
   // Completion Hooks
   //===--------------------------------------------------------------------===//
 
-  /// Signal code completion for a dialect name.
-  virtual void completeDialectName() = 0;
+  /// Signal code completion for a dialect name, with an optional prefix.
+  virtual void completeDialectName(StringRef prefix) = 0;
+  void completeDialectName() { completeDialectName(""); }
 
   /// Signal code completion for an operation name within the given dialect.
   virtual void completeOperationName(StringRef dialectName) = 0;
@@ -48,6 +53,16 @@ class AsmParserCodeCompleteContext {
   virtual void completeExpectedTokens(ArrayRef<StringRef> tokens,
                                       bool optional) = 0;
 
+  /// Signal a completion for an attribute.
+  virtual void completeAttribute(const llvm::StringMap<Attribute> &aliases) = 0;
+  virtual void completeDialectAttributeOrAlias(
+      const llvm::StringMap<Attribute> &aliases) = 0;
+
+  /// Signal a completion for a type.
+  virtual void completeType(const llvm::StringMap<Type> &aliases) = 0;
+  virtual void
+  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) = 0;
+
 protected:
   /// Create a new code completion context with the given code complete
   /// location.

diff  --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index 13e19d4276ebe..5c57f883fd3d9 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -35,11 +35,9 @@ void PDLDialect::registerTypes() {
 
 static Type parsePDLType(AsmParser &parser) {
   StringRef typeTag;
-  if (parser.parseKeyword(&typeTag))
-    return Type();
   {
     Type genType;
-    auto parseResult = generatedTypeParser(parser, typeTag, genType);
+    auto parseResult = generatedTypeParser(parser, &typeTag, genType);
     if (parseResult.hasValue())
       return genType;
   }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 4accfc173a592..90025e97c4e3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -577,17 +577,11 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
 
   // Parse the kind keyword first.
   StringRef attrKind;
-  if (parser.parseKeyword(&attrKind))
-    return {};
-
   Attribute attr;
   OptionalParseResult result =
-      generatedAttributeParser(parser, attrKind, type, attr);
-  if (result.hasValue()) {
-    if (failed(result.getValue()))
-      return {};
+      generatedAttributeParser(parser, &attrKind, type, attr);
+  if (result.hasValue())
     return attr;
-  }
 
   if (attrKind == spirv::TargetEnvAttr::getKindName())
     return parseTargetEnvAttr(parser);

diff  --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
index e15738e8ffdef..7e204bc392c77 100644
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -242,6 +242,56 @@ class AsmParserImpl : public BaseT {
     return success();
   }
 
+  /// Parse a floating point value from the stream.
+  ParseResult parseFloat(double &result) override {
+    bool isNegative = parser.consumeIf(Token::minus);
+    Token curTok = parser.getToken();
+    SMLoc loc = curTok.getLoc();
+
+    // Check for a floating point value.
+    if (curTok.is(Token::floatliteral)) {
+      auto val = curTok.getFloatingPointValue();
+      if (!val)
+        return emitError(loc, "floating point value too large");
+      parser.consumeToken(Token::floatliteral);
+      result = isNegative ? -*val : *val;
+      return success();
+    }
+
+    // Check for a hexadecimal float value.
+    if (curTok.is(Token::integer)) {
+      Optional<APFloat> apResult;
+      if (failed(parser.parseFloatFromIntegerLiteral(
+              apResult, curTok, isNegative, APFloat::IEEEdouble(),
+              /*typeSizeInBits=*/64)))
+        return failure();
+
+      parser.consumeToken(Token::integer);
+      result = apResult->convertToDouble();
+      return success();
+    }
+
+    return emitError(loc, "expected floating point literal");
+  }
+
+  /// Parse an optional integer value from the stream.
+  OptionalParseResult parseOptionalInteger(APInt &result) override {
+    return parser.parseOptionalInteger(result);
+  }
+
+  /// Parse a list of comma-separated items with an optional delimiter.  If a
+  /// delimiter is provided, then an empty list is allowed.  If not, then at
+  /// least one element will be parsed.
+  ParseResult parseCommaSeparatedList(Delimiter delimiter,
+                                      function_ref<ParseResult()> parseElt,
+                                      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Keyword Parsing
+  //===--------------------------------------------------------------------===//
+
   ParseResult parseKeyword(StringRef keyword, const Twine &msg) override {
     if (parser.getToken().isCodeCompletion())
       return parser.codeCompleteExpectedTokens(keyword);
@@ -251,6 +301,7 @@ class AsmParserImpl : public BaseT {
       return emitError(loc, "expected '") << keyword << "'" << msg;
     return success();
   }
+  using AsmParser::parseKeyword;
 
   /// Parse the given keyword if present.
   ParseResult parseOptionalKeyword(StringRef keyword) override {
@@ -308,52 +359,6 @@ class AsmParserImpl : public BaseT {
     return parseOptionalString(result);
   }
 
-  /// Parse a floating point value from the stream.
-  ParseResult parseFloat(double &result) override {
-    bool isNegative = parser.consumeIf(Token::minus);
-    Token curTok = parser.getToken();
-    SMLoc loc = curTok.getLoc();
-
-    // Check for a floating point value.
-    if (curTok.is(Token::floatliteral)) {
-      auto val = curTok.getFloatingPointValue();
-      if (!val)
-        return emitError(loc, "floating point value too large");
-      parser.consumeToken(Token::floatliteral);
-      result = isNegative ? -*val : *val;
-      return success();
-    }
-
-    // Check for a hexadecimal float value.
-    if (curTok.is(Token::integer)) {
-      Optional<APFloat> apResult;
-      if (failed(parser.parseFloatFromIntegerLiteral(
-              apResult, curTok, isNegative, APFloat::IEEEdouble(),
-              /*typeSizeInBits=*/64)))
-        return failure();
-
-      parser.consumeToken(Token::integer);
-      result = apResult->convertToDouble();
-      return success();
-    }
-
-    return emitError(loc, "expected floating point literal");
-  }
-
-  /// Parse an optional integer value from the stream.
-  OptionalParseResult parseOptionalInteger(APInt &result) override {
-    return parser.parseOptionalInteger(result);
-  }
-
-  /// Parse a list of comma-separated items with an optional delimiter.  If a
-  /// delimiter is provided, then an empty list is allowed.  If not, then at
-  /// least one element will be parsed.
-  ParseResult parseCommaSeparatedList(Delimiter delimiter,
-                                      function_ref<ParseResult()> parseElt,
-                                      StringRef contextMessage) override {
-    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
-  }
-
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -528,6 +533,28 @@ class AsmParserImpl : public BaseT {
     return parser.parseXInDimensionList();
   }
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a keyword, or an empty string if the current location signals a code
+  /// completion.
+  ParseResult parseKeywordOrCompletion(StringRef *keyword) override {
+    Token tok = parser.getToken();
+    if (tok.isCodeCompletion() && tok.getSpelling().empty()) {
+      *keyword = "";
+      return success();
+    }
+    return parseKeyword(keyword);
+  }
+
+  /// Signal the code completion of a set of expected tokens.
+  void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) override {
+    Token tok = parser.getToken();
+    if (tok.isCodeCompletion() && tok.getSpelling().empty())
+      (void)parser.codeCompleteExpectedTokens(tokens);
+  }
+
 protected:
   /// The source location of the dialect symbol.
   SMLoc nameLoc;

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 177420668a385..3de6a0d60c238 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -213,6 +213,12 @@ Attribute Parser::parseAttribute(Type type) {
     consumeToken(Token::kw_unit);
     return builder.getUnitAttr();
 
+    // Handle completion of an attribute.
+  case Token::code_complete:
+    if (getToken().isCodeCompletionFor(Token::hash_identifier))
+      return parseExtendedAttr(type);
+    return codeCompleteAttribute();
+
   default:
     // Parse a type attribute. We parse `Optional` here to allow for providing a
     // better error message.

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 4478d1936e661..2e488332c9596 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -43,9 +43,6 @@ class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
 };
 } // namespace
 
-/// Parse the body of a dialect symbol, which starts and ends with <>'s, and may
-/// be recursive. Return with the 'body' StringRef encompassing the entire
-/// body.
 ///
 ///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
 ///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
@@ -54,7 +51,8 @@ class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
 ///                                  | '{' pretty-dialect-sym-contents+ '}'
 ///                                  | '[^[<({>\])}\0]+'
 ///
-ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
+ParseResult Parser::parseDialectSymbolBody(StringRef &body,
+                                           bool &isCodeCompletion) {
   // Symbol bodies are a relatively unstructured format that contains a series
   // of properly nested punctuation, with anything else in the middle. Scan
   // ahead to find it and consume it if successful, otherwise emit an error.
@@ -65,7 +63,16 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
   // go until we find the matching '>' character.
   assert(*curPtr == '<');
   SmallVector<char, 8> nestedPunctuation;
+  const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
   do {
+    // Handle code completions, which may appear in the middle of the symbol
+    // body.
+    if (curPtr == codeCompleteLoc) {
+      isCodeCompletion = true;
+      nestedPunctuation.clear();
+      break;
+    }
+
     char c = *curPtr++;
     switch (c) {
     case '\0':
@@ -107,9 +114,19 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
     case '"': {
       // Dispatch to the lexer to lex past strings.
       resetToken(curPtr - 1);
+      curPtr = state.curToken.getEndLoc().getPointer();
+
+      // Handle code completions, which may appear in the middle of the symbol
+      // body.
+      if (state.curToken.isCodeCompletion()) {
+        isCodeCompletion = true;
+        nestedPunctuation.clear();
+        break;
+      }
+
+      // Otherwise, ensure this token was actually a string.
       if (state.curToken.isNot(Token::string))
         return failure();
-      curPtr = state.curToken.getEndLoc().getPointer();
       break;
     }
 
@@ -129,19 +146,24 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
 
 /// Parse an extended dialect symbol.
 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
-static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
-                                  SymbolAliasMap &aliases,
+static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
                                   CreateFn &&createSymbol) {
+  Token tok = p.getToken();
+
+  // Handle code completion of the extended symbol.
+  StringRef identifier = tok.getSpelling().drop_front();
+  if (tok.isCodeCompletion() && identifier.empty())
+    return p.codeCompleteDialectSymbol(aliases);
+
   // Parse the dialect namespace.
-  StringRef identifier = p.getTokenSpelling().drop_front();
   SMLoc loc = p.getToken().getLoc();
-  p.consumeToken(identifierTok);
+  p.consumeToken();
 
   // Check to see if this is a pretty name.
   StringRef dialectName;
   StringRef symbolData;
   std::tie(dialectName, symbolData) = identifier.split('.');
-  bool isPrettyName = !symbolData.empty();
+  bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
 
   // Check to see if the symbol has trailing data, i.e. has an immediately
   // following '<'.
@@ -167,9 +189,17 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
   if (!isPrettyName) {
     // Point the symbol data to the end of the dialect name to start.
     symbolData = StringRef(dialectName.end(), 0);
-    if (p.parseDialectSymbolBody(symbolData))
+
+    // Parse the body of the symbol.
+    bool isCodeCompletion = false;
+    if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
       return nullptr;
-    symbolData = symbolData.drop_front().drop_back();
+    symbolData = symbolData.drop_front();
+
+    // If the body contained a code completion it won't have the trailing `>`
+    // token, so don't drop it.
+    if (!isCodeCompletion)
+      symbolData = symbolData.drop_back();
   } else {
     loc = SMLoc::getFromPointer(symbolData.data());
 
@@ -192,7 +222,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
 Attribute Parser::parseExtendedAttr(Type type) {
   MLIRContext *ctx = getContext();
   Attribute attr = parseExtendedSymbol<Attribute>(
-      *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
+      *this, state.symbols.attributeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
         // Parse an optional trailing colon type.
         Type attrType = type;
@@ -238,7 +268,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
 Type Parser::parseExtendedType() {
   MLIRContext *ctx = getContext();
   return parseExtendedSymbol<Type>(
-      *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
+      *this, state.symbols.typeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
         // If we found a registered dialect, then ask it to parse the type.
         if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {

diff  --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h
index e09ae168e3bba..10ef3bf6429b4 100644
--- a/mlir/lib/Parser/Lexer.h
+++ b/mlir/lib/Parser/Lexer.h
@@ -40,6 +40,10 @@ class Lexer {
   /// Returns the start of the buffer.
   const char *getBufferBegin() { return curBuffer.data(); }
 
+  /// Return the code completion location of the lexer, or nullptr if there is
+  /// none.
+  const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
+
 private:
   // Helpers.
   Token formToken(Token::Kind kind, const char *tokStart) {

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index be4f886a9d0fb..e27f72c221714 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -404,6 +404,26 @@ ParseResult Parser::codeCompleteOptionalTokens(ArrayRef<StringRef> tokens) {
   return failure();
 }
 
+Attribute Parser::codeCompleteAttribute() {
+  state.codeCompleteContext->completeAttribute(
+      state.symbols.attributeAliasDefinitions);
+  return {};
+}
+Type Parser::codeCompleteType() {
+  state.codeCompleteContext->completeType(state.symbols.typeAliasDefinitions);
+  return {};
+}
+
+Attribute
+Parser::codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases) {
+  state.codeCompleteContext->completeDialectAttributeOrAlias(aliases);
+  return {};
+}
+Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
+  state.codeCompleteContext->completeDialectTypeOrAlias(aliases);
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // OperationParser
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 34a95a550e9d1..ce0d1b3b1b714 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -57,7 +57,16 @@ class Parser {
     return parseCommaSeparatedList(Delimiter::None, parseElementFn);
   }
 
-  ParseResult parseDialectSymbolBody(StringRef &body);
+  /// Parse the body of a dialect symbol, which starts and ends with <>'s, and
+  /// may be recursive. Return with the 'body' StringRef encompassing the entire
+  /// body. `isCodeCompletion` is set to true if the body contained a code
+  /// completion location, in which case the body is only populated up to the
+  /// completion.
+  ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion);
+  ParseResult parseDialectSymbolBody(StringRef &body) {
+    bool isCodeCompletion = false;
+    return parseDialectSymbolBody(body, isCodeCompletion);
+  }
 
   // We have two forms of parsing methods - those that return a non-null
   // pointer on success, and those that return a ParseResult to indicate whether
@@ -322,6 +331,12 @@ class Parser {
   ParseResult codeCompleteExpectedTokens(ArrayRef<StringRef> tokens);
   ParseResult codeCompleteOptionalTokens(ArrayRef<StringRef> tokens);
 
+  Attribute codeCompleteAttribute();
+  Type codeCompleteType();
+  Attribute
+  codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases);
+  Type codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases);
+
 protected:
   /// The Parser is subclassed and reinstantiated.  Do not add additional
   /// non-trivial state here, add it to the ParserState class.

diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 660dfe1cbea28..a916ec72e5f4d 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -358,6 +358,12 @@ Type Parser::parseNonFunctionType() {
   // extended type
   case Token::exclamation_identifier:
     return parseExtendedType();
+
+  // Handle completion of a dialect type.
+  case Token::code_complete:
+    if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
+      return parseExtendedType();
+    return codeCompleteType();
   }
 }
 

diff  --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index 03fe3ca6e41e7..0fc30d5bcced2 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -781,8 +781,9 @@ enum class InsertTextFormat {
 
 struct CompletionItem {
   CompletionItem() = default;
-  CompletionItem(StringRef label, CompletionItemKind kind)
-      : label(label.str()), kind(kind),
+  CompletionItem(const Twine &label, CompletionItemKind kind,
+                 StringRef sortText = "")
+      : label(label.str()), kind(kind), sortText(sortText.str()),
         insertTextFormat(InsertTextFormat::PlainText) {}
 
   /// The label of this completion item. By default also the text that is

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 9d3d51b48b2e1..a2f4f6f8d36d1 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -636,15 +636,17 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
       : AsmParserCodeCompleteContext(completeLoc),
         completionList(completionList), ctx(ctx) {}
 
-  /// Signal code completion for a dialect name.
-  void completeDialectName() final {
+  /// Signal code completion for a dialect name, with an optional prefix.
+  void completeDialectName(StringRef prefix) final {
     for (StringRef dialect : ctx->getAvailableDialects()) {
-      lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
-      item.sortText = "2";
+      lsp::CompletionItem item(prefix + dialect,
+                               lsp::CompletionItemKind::Module,
+                               /*sortText=*/"3");
       item.detail = "dialect";
       completionList.items.emplace_back(item);
     }
   }
+  using AsmParserCodeCompleteContext::completeDialectName;
 
   /// Signal code completion for an operation name within the given dialect.
   void completeOperationName(StringRef dialectName) final {
@@ -658,8 +660,8 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
 
       lsp::CompletionItem item(
           op.getStringRef().drop_front(dialectName.size() + 1),
-          lsp::CompletionItemKind::Field);
-      item.sortText = "1";
+          lsp::CompletionItemKind::Field,
+          /*sortText=*/"1");
       item.detail = "operation";
       completionList.items.emplace_back(item);
     }
@@ -693,13 +695,71 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
   /// Signal a completion for the given expected token.
   void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
     for (StringRef token : tokens) {
-      lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword);
-      item.sortText = "0";
+      lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
+                               /*sortText=*/"0");
       item.detail = optional ? "optional" : "";
       completionList.items.emplace_back(item);
     }
   }
 
+  /// Signal a completion for an attribute.
+  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
+    appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
+                             "loc", "opaque", "sparse", "true", "unit"},
+                            lsp::CompletionItemKind::Field,
+                            /*sortText=*/"1");
+
+    completeDialectName("#");
+    completeAliases(aliases, "#");
+  }
+  void completeDialectAttributeOrAlias(
+      const llvm::StringMap<Attribute> &aliases) override {
+    completeDialectName();
+    completeAliases(aliases);
+  }
+
+  /// Signal a completion for a type.
+  void completeType(const llvm::StringMap<Type> &aliases) override {
+    appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
+                             "bf16", "f16", "f32", "f64", "f80", "f128",
+                             "index", "none"},
+                            lsp::CompletionItemKind::Field,
+                            /*sortText=*/"1");
+    lsp::CompletionItem item("i<N>", lsp::CompletionItemKind::Field,
+                             /*sortText=*/"1");
+    item.insertText = "i";
+    completionList.items.emplace_back(item);
+
+    completeDialectName("!");
+    completeAliases(aliases, "!");
+  }
+  void
+  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
+    completeDialectName();
+    completeAliases(aliases);
+  }
+
+  /// Add completion results for the given set of aliases.
+  template <typename T>
+  void completeAliases(const llvm::StringMap<T> &aliases,
+                       StringRef prefix = "") {
+    for (const auto &alias : aliases) {
+      lsp::CompletionItem item(prefix + alias.getKey(),
+                               lsp::CompletionItemKind::Field,
+                               /*sortText=*/"2");
+      llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
+      completionList.items.emplace_back(item);
+    }
+  }
+
+  /// Add a set of simple completions that all have the same kind.
+  void appendSimpleCompletions(ArrayRef<StringRef> completions,
+                               lsp::CompletionItemKind kind,
+                               StringRef sortText = "") {
+    for (StringRef completion : completions)
+      completionList.items.emplace_back(completion, kind, sortText);
+  }
+
 private:
   lsp::CompletionList &completionList;
   MLIRContext *ctx;

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 5418025d501c9..df56055294385 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -408,12 +408,9 @@ void TestDialect::registerTypes() {
 Type TestDialect::parseTestType(AsmParser &parser,
                                 SetVector<Type> &stack) const {
   StringRef typeTag;
-  if (failed(parser.parseKeyword(&typeTag)))
-    return Type();
-
   {
     Type genType;
-    auto parseResult = generatedTypeParser(parser, typeTag, genType);
+    auto parseResult = generatedTypeParser(parser, &typeTag, genType);
     if (parseResult.hasValue())
       return genType;
   }

diff  --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test
index 4264fe69bc737..fe067b256f7ac 100644
--- a/mlir/test/mlir-lsp-server/completion.test
+++ b/mlir/test/mlir-lsp-server/completion.test
@@ -5,14 +5,14 @@
   "uri":"test:///foo.mlir",
   "languageId":"mlir",
   "version":1,
-  "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %"
+  "text":"#attr = i32\n!alias = i32\nfunc.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (!pdl.value)\nreturn %"
 }}}
 // -----
 {"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
   "textDocument":{"uri":"test:///foo.mlir"},
-  "position":{"line":0,"character":0}
+  "position":{"line":2,"character":0}
 }}
-//      CHECK:  "id": 1
+// CHECK-LABEL: "id": 1
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "isIncomplete": false,
@@ -22,7 +22,7 @@
 // CHECK:             "insertTextFormat": 1,
 // CHECK:             "kind": 9,
 // CHECK:             "label": "builtin",
-// CHECK:             "sortText": "2"
+// CHECK:             "sortText": "3"
 // CHECK:           },
 // CHECK:           {
 // CHECK:             "detail": "operation",
@@ -34,11 +34,11 @@
 // CHECK:         ]
 // CHECK-NEXT:  }
 // -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":2,"method":"textDocument/completion","params":{
   "textDocument":{"uri":"test:///foo.mlir"},
-  "position":{"line":1,"character":9}
+  "position":{"line":3,"character":9}
 }}
-//      CHECK:  "id": 1
+// CHECK-LABEL: "id": 2
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "isIncomplete": false,
@@ -48,17 +48,17 @@
 // CHECK:             "insertTextFormat": 1,
 // CHECK:             "kind": 9,
 // CHECK:             "label": "builtin",
-// CHECK:             "sortText": "2"
+// CHECK:             "sortText": "3"
 // CHECK:           },
 // CHECK-NOT:       "detail": "operation",
 // CHECK:         ]
 // CHECK-NEXT:  }
 // -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":3,"method":"textDocument/completion","params":{
   "textDocument":{"uri":"test:///foo.mlir"},
-  "position":{"line":1,"character":17}
+  "position":{"line":3,"character":17}
 }}
-//      CHECK:  "id": 1
+// CHECK-LABEL: "id": 3
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "isIncomplete": false,
@@ -74,17 +74,17 @@
 // CHECK:         ]
 // CHECK-NEXT:  }
 // -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":4,"method":"textDocument/completion","params":{
   "textDocument":{"uri":"test:///foo.mlir"},
-  "position":{"line":2,"character":8}
+  "position":{"line":4,"character":8}
 }}
-//      CHECK:  "id": 1
+// CHECK-LABEL: "id": 4
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "isIncomplete": false,
 // CHECK-NEXT:    "items": [
 // CHECK-NEXT:      {
-// CHECK-NEXT:        "detail": "builtin.unrealized_conversion_cast: i32",
+// CHECK-NEXT:        "detail": "builtin.unrealized_conversion_cast: !pdl.value",
 // CHECK-NEXT:        "insertText": "cast",
 // CHECK-NEXT:        "insertTextFormat": 1,
 // CHECK-NEXT:        "kind": 6,
@@ -100,11 +100,11 @@
 // CHECK:         ]
 // CHECK-NEXT:  }
 // -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":5,"method":"textDocument/completion","params":{
   "textDocument":{"uri":"test:///foo.mlir"},
-  "position":{"line":0,"character":10}
+  "position":{"line":2,"character":10}
 }}
-//      CHECK:  "id": 1
+// CHECK-LABEL: "id": 5
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "isIncomplete": false,
@@ -133,6 +133,134 @@
 // CHECK-NEXT:    ]
 // CHECK-NEXT:  }
 // -----
-{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+{"jsonrpc":"2.0","id":6,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":0,"character":8}
+}}
+// CHECK-LABEL: "id": 6
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK:           {
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "false"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "loc"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "true"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "unit"
+// CHECK:           }
+// CHECK:    ]
+// CHECK:  }
+// -----
+{"jsonrpc":"2.0","id":7,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":3,"character":56}
+}}
+// CHECK-LABEL: "id": 7
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK:           {
+// CHECK:              "insertTextFormat": 1,
+// CHECK:              "kind": 5,
+// CHECK:             "label": "index"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "none"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "insertText": "i",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "i<N>"
+// CHECK:           }
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":8,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":3,"character":57}
+}}
+// CHECK-LABEL: "id": 8
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK:           {
+// CHECK:             "detail": "dialect",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 9,
+// CHECK:             "label": "builtin",
+// CHECK:             "sortText": "3"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "detail": "alias: i32",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "alias",
+// CHECK:             "sortText": "2"
+// CHECK:           }
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":9,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":3,"character":61}
+}}
+// CHECK-LABEL: "id": 9
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 14,
+// CHECK-NEXT:        "label": "attribute",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 14,
+// CHECK-NEXT:        "label": "operation",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 14,
+// CHECK-NEXT:        "label": "range",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 14,
+// CHECK-NEXT:        "label": "type",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 14,
+// CHECK-NEXT:        "label": "value",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      }
+// CHECK-NEXT:    ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":10,"method":"shutdown"}
 // -----
 {"jsonrpc":"2.0","method":"exit"}

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 3737ce77948e2..50dba45c63d0b 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -21,16 +21,19 @@ include "mlir/IR/OpBase.td"
 
 // DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
 // DEF-SAME: ::mlir::AsmParser &parser,
-// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type,
+// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,
 // DEF-SAME: ::mlir::Attribute &value) {
-// DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) {
+// DEF: return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
+// DEF: .Case(::test::CompoundAAttr::getMnemonic()
 // DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type);
 // DEF-NEXT: return ::mlir::success(!!value);
-// DEF-NEXT: }
-// DEF-NEXT: if (mnemonic == ::test::IndexAttr::getMnemonic()) {
+// DEF-NEXT: })
+// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
 // DEF-NEXT:   value = ::test::IndexAttr::parse(parser, type);
 // DEF-NEXT:   return ::mlir::success(!!value);
-// DEF: return {};
+// DEF: .Default([&](llvm::StringRef keyword, 
+// DEF-NEXT:   *mnemonic = keyword;
+// DEF-NEXT:   return llvm::None; 
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect

diff  --git a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
index ac898a97d9eee..1c479f11abbf6 100644
--- a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
+++ b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
@@ -27,11 +27,9 @@ def AttrA : TestAttr<"AttrA"> {
 // ATTR:                                               ::mlir::Type type) const {
 // ATTR:   ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
 // ATTR:   ::llvm::StringRef attrTag;
-// ATTR:   if (::mlir::failed(parser.parseKeyword(&attrTag)))
-// ATTR:     return {};
 // ATTR:   {
 // ATTR:     ::mlir::Attribute attr;
-// ATTR:     auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
+// ATTR:     auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
 // ATTR:     if (parseResult.hasValue())
 // ATTR:       return attr;
 // ATTR:   }
@@ -57,10 +55,8 @@ def TypeA : TestType<"TypeA"> {
 // TYPE: ::mlir::Type TestDialect::parseType(::mlir::DialectAsmParser &parser) const {
 // TYPE:   ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
 // TYPE:   ::llvm::StringRef mnemonic;
-// TYPE:   if (parser.parseKeyword(&mnemonic))
-// TYPE:     return ::mlir::Type();
 // TYPE:   ::mlir::Type genType;
-// TYPE:   auto parseResult = generatedTypeParser(parser, mnemonic, genType);
+// TYPE:   auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
 // TYPE:   if (parseResult.hasValue())
 // TYPE:     return genType;
 // TYPE:   parser.emitError(typeLoc) << "unknown  type `"

diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 833195530f9b0..29ebad5439c08 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -22,16 +22,18 @@ include "mlir/IR/OpBase.td"
 
 // DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser(
 // DEF-SAME: ::mlir::AsmParser &parser,
-// DEF-SAME: ::llvm::StringRef mnemonic,
+// DEF-SAME: ::llvm::StringRef *mnemonic,
 // DEF-SAME: ::mlir::Type &value) {
-// DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) {
+// DEF: .Case(::test::CompoundAType::getMnemonic()
 // DEF-NEXT:   value = ::test::CompoundAType::parse(parser);
 // DEF-NEXT:   return ::mlir::success(!!value);
-// DEF-NEXT: }
-// DEF-NEXT: if (mnemonic == ::test::IndexType::getMnemonic()) {
+// DEF-NEXT: })
+// DEF-NEXT: .Case(::test::IndexType::getMnemonic()
 // DEF-NEXT:   value = ::test::IndexType::parse(parser);
 // DEF-NEXT:   return ::mlir::success(!!value);
-// DEF: return {};
+// DEF: .Default([&](llvm::StringRef keyword, 
+// DEF-NEXT:   *mnemonic = keyword;
+// DEF-NEXT:   return llvm::None; 
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 759143dc606b0..db443e9ac0cff 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -673,11 +673,9 @@ ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                       ::mlir::Type type) const {{
   ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
   ::llvm::StringRef attrTag;
-  if (::mlir::failed(parser.parseKeyword(&attrTag)))
-    return {{};
   {{
     ::mlir::Attribute attr;
-    auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
+    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
     if (parseResult.hasValue())
       return attr;
   }
@@ -723,10 +721,8 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"(
 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
   ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
   ::llvm::StringRef mnemonic;
-  if (parser.parseKeyword(&mnemonic))
-    return ::mlir::Type();
   ::mlir::Type genType;
-  auto parseResult = generatedTypeParser(parser, mnemonic, genType);
+  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
   if (parseResult.hasValue())
     return genType;
   {1}
@@ -771,7 +767,7 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
   }
   // Declare the parser.
   SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
-                                         {"::llvm::StringRef", "mnemonic"}};
+                                         {"::llvm::StringRef *", "mnemonic"}};
   if (isAttrGenerator)
     params.emplace_back("::mlir::Type", "type");
   params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
@@ -784,14 +780,18 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
                  {{strfmt("::mlir::{0}", valueType), "def"},
                   {"::mlir::AsmPrinter &", "printer"}});
 
-  // The parser dispatch is just a list of if-elses, matching on the mnemonic
-  // and calling the def's parse function.
+  // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
+  // calling the def's parse function.
+  parse.body() << "  return "
+                  "::mlir::AsmParser::KeywordSwitch<::mlir::"
+                  "OptionalParseResult>(parser)\n";
   const char *const getValueForMnemonic =
-      R"(  if (mnemonic == {0}::getMnemonic()) {{
-    value = {0}::{1};
-    return ::mlir::success(!!value);
-  }
+      R"(    .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
+      value = {0}::{1};
+      return ::mlir::success(!!value);
+    })
 )";
+
   // The printer dispatch uses llvm::TypeSwitch to find and call the correct
   // printer.
   printer.body() << "  return ::llvm::TypeSwitch<::mlir::" << valueType
@@ -822,7 +822,10 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
       printDef = "\nt.print(printer);";
     printer.body() << llvm::formatv(printValue, defClass, printDef);
   }
-  parse.body() << "  return {};";
+  parse.body() << "    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
+                  "      *mnemonic = keyword;\n"
+                  "      return llvm::None;\n"
+                  "    });";
   printer.body() << "    .Default([](auto) { return ::mlir::failure(); });";
 
   raw_indented_ostream indentedOs(os);


        


More information about the Mlir-commits mailing list