[Mlir-commits] [mlir] fe4f512 - [mlir:LSP] Add support for code completing attributes and types
River Riddle
llvmlistbot at llvm.org
Fri Jul 8 16:56:29 PDT 2022
Author: River Riddle
Date: 2022-07-08T16:24:55-07:00
New Revision: fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9
URL: https://github.com/llvm/llvm-project/commit/fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9
DIFF: https://github.com/llvm/llvm-project/commit/fe4f512be7a57ea7bbaa36b5261c2fa00e306cf9.diff
LOG: [mlir:LSP] Add support for code completing attributes and types
This required changing a bit of how attributes/types are parsed. A new
`KeywordSwitch` class was added to AsmParser that provides a StringSwitch-like
API for parsing keywords with a set of potential matches. It is intended to
both provide a cleaner API and enable injection of code completion. This
required changing the API of `generated(Attr|Type)Parser` to handle the
parsing of the keyword, instead of having the user do it. Most upstream
dialects use the autogenerated handling and didn't require a direct update.
Differential Revision: https://reviews.llvm.org/D129267
Added:
Modified:
flang/lib/Optimizer/Dialect/FIRType.cpp
mlir/docs/AttributesAndTypes.md
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/Parser/CodeComplete.h
mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
mlir/lib/Parser/AsmParserImpl.h
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Lexer.h
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/TypeParser.cpp
mlir/lib/Tools/lsp-server-support/Protocol.h
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-lsp-server/completion.test
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/default-type-attr-print-parser.td
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 0a8a2eb0385b7..77075395576da 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -116,10 +116,8 @@ RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
mlir::Type fir::parseFirType(FIROpsDialect *dialect,
mlir::DialectAsmParser &parser) {
mlir::StringRef typeTag;
- if (parser.parseKeyword(&typeTag))
- return {};
mlir::Type genType;
- auto parseResult = generatedTypeParser(parser, typeTag, genType);
+ auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 36c134ee9435b..e1ad13952fe0b 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -473,10 +473,10 @@ one for printing. These static functions placed alongside the class definitions
and have the following function signatures:
```c++
-static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef mnemonic, Type attrType, Attribute &result);
+static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef *mnemonic, Type attrType, Attribute &result);
static LogicalResult generatedAttributePrinter(Attribute attr, DialectAsmPrinter& printer);
-static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef mnemonic, Type &result);
+static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef *mnemonic, Type &result);
static LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
```
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 33967b0e42943..cb240a40e0c79 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -571,43 +571,6 @@ class AsmParser {
/// Parse a quoted string token if present.
virtual ParseResult parseOptionalString(std::string *string) = 0;
- /// Parse a given keyword.
- ParseResult parseKeyword(StringRef keyword) {
- return parseKeyword(keyword, "");
- }
- virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
-
- /// Parse a keyword into 'keyword'.
- ParseResult parseKeyword(StringRef *keyword) {
- auto loc = getCurrentLocation();
- if (parseOptionalKeyword(keyword))
- return emitError(loc, "expected valid keyword");
- return success();
- }
-
- /// Parse the given keyword if present.
- virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
-
- /// Parse a keyword, if present, into 'keyword'.
- virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
-
- /// Parse a keyword, if present, and if one of the 'allowedValues',
- /// into 'keyword'
- virtual ParseResult
- parseOptionalKeyword(StringRef *keyword,
- ArrayRef<StringRef> allowedValues) = 0;
-
- /// Parse a keyword or a quoted string.
- ParseResult parseKeywordOrString(std::string *result) {
- if (failed(parseOptionalKeywordOrString(result)))
- return emitError(getCurrentLocation())
- << "expected valid keyword or string";
- return success();
- }
-
- /// Parse an optional keyword or string.
- virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
-
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;
@@ -712,6 +675,115 @@ class AsmParser {
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
+ //===--------------------------------------------------------------------===//
+ // Keyword Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// This class represents a StringSwitch like class that is useful for parsing
+ /// expected keywords. On construction, it invokes `parseKeyword` and
+ /// processes each of the provided cases statements until a match is hit. The
+ /// provided `ResultT` must be assignable from `failure()`.
+ template <typename ResultT = ParseResult>
+ class KeywordSwitch {
+ public:
+ KeywordSwitch(AsmParser &parser)
+ : parser(parser), loc(parser.getCurrentLocation()) {
+ if (failed(parser.parseKeywordOrCompletion(&keyword)))
+ result = failure();
+ }
+
+ /// Case that uses the provided value when true.
+ KeywordSwitch &Case(StringLiteral str, ResultT value) {
+ return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
+ }
+ KeywordSwitch &Default(ResultT value) {
+ return Default([&](StringRef, SMLoc) { return std::move(value); });
+ }
+ /// Case that invokes the provided functor when true. The parameters passed
+ /// to the functor are the keyword, and the location of the keyword (in case
+ /// any errors need to be emitted).
+ template <typename FnT>
+ std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
+ Case(StringLiteral str, FnT &&fn) {
+ if (result)
+ return *this;
+
+ // If the word was empty, record this as a completion.
+ if (keyword.empty())
+ parser.codeCompleteExpectedTokens(str);
+ else if (keyword == str)
+ result.emplace(std::move(fn(keyword, loc)));
+ return *this;
+ }
+ template <typename FnT>
+ std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
+ Default(FnT &&fn) {
+ if (!result)
+ result.emplace(fn(keyword, loc));
+ return *this;
+ }
+
+ /// Returns true if this switch has a value yet.
+ bool hasValue() const { return result.hasValue(); }
+
+ /// Return the result of the switch.
+ LLVM_NODISCARD operator ResultT() {
+ if (!result)
+ return parser.emitError(loc, "unexpected keyword: ") << keyword;
+ return std::move(*result);
+ }
+
+ private:
+ /// The parser used to construct this switch.
+ AsmParser &parser;
+
+ /// The location of the keyword, used to emit errors as necessary.
+ SMLoc loc;
+
+ /// The parsed keyword itself.
+ StringRef keyword;
+
+ /// The result of the switch statement or none if currently unknown.
+ Optional<ResultT> result;
+ };
+
+ /// Parse a given keyword.
+ ParseResult parseKeyword(StringRef keyword) {
+ return parseKeyword(keyword, "");
+ }
+ virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
+
+ /// Parse a keyword into 'keyword'.
+ ParseResult parseKeyword(StringRef *keyword) {
+ auto loc = getCurrentLocation();
+ if (parseOptionalKeyword(keyword))
+ return emitError(loc, "expected valid keyword");
+ return success();
+ }
+
+ /// Parse the given keyword if present.
+ virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
+
+ /// Parse a keyword, if present, into 'keyword'.
+ virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
+
+ /// Parse a keyword, if present, and if one of the 'allowedValues',
+ /// into 'keyword'
+ virtual ParseResult
+ parseOptionalKeyword(StringRef *keyword,
+ ArrayRef<StringRef> allowedValues) = 0;
+
+ /// Parse a keyword or a quoted string.
+ ParseResult parseKeywordOrString(std::string *result) {
+ if (failed(parseOptionalKeywordOrString(result)))
+ return emitError(getCurrentLocation())
+ << "expected valid keyword or string";
+ return success();
+ }
+
+ /// Parse an optional keyword or string.
+ virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
+
//===--------------------------------------------------------------------===//
// Attribute/Type Parsing
//===--------------------------------------------------------------------===//
@@ -1124,6 +1196,17 @@ class AsmParser {
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a keyword, or an empty string if the current location signals a code
+ /// completion.
+ virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
+
+ /// Signal the code completion of a set of expected tokens.
+ virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
+
private:
AsmParser(const AsmParser &) = delete;
void operator=(const AsmParser &) = delete;
diff --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h
index 7dd7b6151aa69..1dcb2745d4bb5 100644
--- a/mlir/include/mlir/Parser/CodeComplete.h
+++ b/mlir/include/mlir/Parser/CodeComplete.h
@@ -10,9 +10,13 @@
#define MLIR_PARSER_CODECOMPLETE_H
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
+class Attribute;
+class Type;
+
/// This class provides an abstract interface into the parser for hooking in
/// code completion events. This class is only really useful for providing
/// language tooling for MLIR, general clients should not need to use this
@@ -28,8 +32,9 @@ class AsmParserCodeCompleteContext {
// Completion Hooks
//===--------------------------------------------------------------------===//
- /// Signal code completion for a dialect name.
- virtual void completeDialectName() = 0;
+ /// Signal code completion for a dialect name, with an optional prefix.
+ virtual void completeDialectName(StringRef prefix) = 0;
+ void completeDialectName() { completeDialectName(""); }
/// Signal code completion for an operation name within the given dialect.
virtual void completeOperationName(StringRef dialectName) = 0;
@@ -48,6 +53,16 @@ class AsmParserCodeCompleteContext {
virtual void completeExpectedTokens(ArrayRef<StringRef> tokens,
bool optional) = 0;
+ /// Signal a completion for an attribute.
+ virtual void completeAttribute(const llvm::StringMap<Attribute> &aliases) = 0;
+ virtual void completeDialectAttributeOrAlias(
+ const llvm::StringMap<Attribute> &aliases) = 0;
+
+ /// Signal a completion for a type.
+ virtual void completeType(const llvm::StringMap<Type> &aliases) = 0;
+ virtual void
+ completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) = 0;
+
protected:
/// Create a new code completion context with the given code complete
/// location.
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index 13e19d4276ebe..5c57f883fd3d9 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -35,11 +35,9 @@ void PDLDialect::registerTypes() {
static Type parsePDLType(AsmParser &parser) {
StringRef typeTag;
- if (parser.parseKeyword(&typeTag))
- return Type();
{
Type genType;
- auto parseResult = generatedTypeParser(parser, typeTag, genType);
+ auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 4accfc173a592..90025e97c4e3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -577,17 +577,11 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
// Parse the kind keyword first.
StringRef attrKind;
- if (parser.parseKeyword(&attrKind))
- return {};
-
Attribute attr;
OptionalParseResult result =
- generatedAttributeParser(parser, attrKind, type, attr);
- if (result.hasValue()) {
- if (failed(result.getValue()))
- return {};
+ generatedAttributeParser(parser, &attrKind, type, attr);
+ if (result.hasValue())
return attr;
- }
if (attrKind == spirv::TargetEnvAttr::getKindName())
return parseTargetEnvAttr(parser);
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
index e15738e8ffdef..7e204bc392c77 100644
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -242,6 +242,56 @@ class AsmParserImpl : public BaseT {
return success();
}
+ /// Parse a floating point value from the stream.
+ ParseResult parseFloat(double &result) override {
+ bool isNegative = parser.consumeIf(Token::minus);
+ Token curTok = parser.getToken();
+ SMLoc loc = curTok.getLoc();
+
+ // Check for a floating point value.
+ if (curTok.is(Token::floatliteral)) {
+ auto val = curTok.getFloatingPointValue();
+ if (!val)
+ return emitError(loc, "floating point value too large");
+ parser.consumeToken(Token::floatliteral);
+ result = isNegative ? -*val : *val;
+ return success();
+ }
+
+ // Check for a hexadecimal float value.
+ if (curTok.is(Token::integer)) {
+ Optional<APFloat> apResult;
+ if (failed(parser.parseFloatFromIntegerLiteral(
+ apResult, curTok, isNegative, APFloat::IEEEdouble(),
+ /*typeSizeInBits=*/64)))
+ return failure();
+
+ parser.consumeToken(Token::integer);
+ result = apResult->convertToDouble();
+ return success();
+ }
+
+ return emitError(loc, "expected floating point literal");
+ }
+
+ /// Parse an optional integer value from the stream.
+ OptionalParseResult parseOptionalInteger(APInt &result) override {
+ return parser.parseOptionalInteger(result);
+ }
+
+ /// Parse a list of comma-separated items with an optional delimiter. If a
+ /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElt,
+ StringRef contextMessage) override {
+ return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Keyword Parsing
+ //===--------------------------------------------------------------------===//
+
ParseResult parseKeyword(StringRef keyword, const Twine &msg) override {
if (parser.getToken().isCodeCompletion())
return parser.codeCompleteExpectedTokens(keyword);
@@ -251,6 +301,7 @@ class AsmParserImpl : public BaseT {
return emitError(loc, "expected '") << keyword << "'" << msg;
return success();
}
+ using AsmParser::parseKeyword;
/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
@@ -308,52 +359,6 @@ class AsmParserImpl : public BaseT {
return parseOptionalString(result);
}
- /// Parse a floating point value from the stream.
- ParseResult parseFloat(double &result) override {
- bool isNegative = parser.consumeIf(Token::minus);
- Token curTok = parser.getToken();
- SMLoc loc = curTok.getLoc();
-
- // Check for a floating point value.
- if (curTok.is(Token::floatliteral)) {
- auto val = curTok.getFloatingPointValue();
- if (!val)
- return emitError(loc, "floating point value too large");
- parser.consumeToken(Token::floatliteral);
- result = isNegative ? -*val : *val;
- return success();
- }
-
- // Check for a hexadecimal float value.
- if (curTok.is(Token::integer)) {
- Optional<APFloat> apResult;
- if (failed(parser.parseFloatFromIntegerLiteral(
- apResult, curTok, isNegative, APFloat::IEEEdouble(),
- /*typeSizeInBits=*/64)))
- return failure();
-
- parser.consumeToken(Token::integer);
- result = apResult->convertToDouble();
- return success();
- }
-
- return emitError(loc, "expected floating point literal");
- }
-
- /// Parse an optional integer value from the stream.
- OptionalParseResult parseOptionalInteger(APInt &result) override {
- return parser.parseOptionalInteger(result);
- }
-
- /// Parse a list of comma-separated items with an optional delimiter. If a
- /// delimiter is provided, then an empty list is allowed. If not, then at
- /// least one element will be parsed.
- ParseResult parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElt,
- StringRef contextMessage) override {
- return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
- }
-
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@@ -528,6 +533,28 @@ class AsmParserImpl : public BaseT {
return parser.parseXInDimensionList();
}
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a keyword, or an empty string if the current location signals a code
+ /// completion.
+ ParseResult parseKeywordOrCompletion(StringRef *keyword) override {
+ Token tok = parser.getToken();
+ if (tok.isCodeCompletion() && tok.getSpelling().empty()) {
+ *keyword = "";
+ return success();
+ }
+ return parseKeyword(keyword);
+ }
+
+ /// Signal the code completion of a set of expected tokens.
+ void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) override {
+ Token tok = parser.getToken();
+ if (tok.isCodeCompletion() && tok.getSpelling().empty())
+ (void)parser.codeCompleteExpectedTokens(tokens);
+ }
+
protected:
/// The source location of the dialect symbol.
SMLoc nameLoc;
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 177420668a385..3de6a0d60c238 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -213,6 +213,12 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::kw_unit);
return builder.getUnitAttr();
+ // Handle completion of an attribute.
+ case Token::code_complete:
+ if (getToken().isCodeCompletionFor(Token::hash_identifier))
+ return parseExtendedAttr(type);
+ return codeCompleteAttribute();
+
default:
// Parse a type attribute. We parse `Optional` here to allow for providing a
// better error message.
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 4478d1936e661..2e488332c9596 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -43,9 +43,6 @@ class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
};
} // namespace
-/// Parse the body of a dialect symbol, which starts and ends with <>'s, and may
-/// be recursive. Return with the 'body' StringRef encompassing the entire
-/// body.
///
/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
@@ -54,7 +51,8 @@ class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
/// | '{' pretty-dialect-sym-contents+ '}'
/// | '[^[<({>\])}\0]+'
///
-ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
+ParseResult Parser::parseDialectSymbolBody(StringRef &body,
+ bool &isCodeCompletion) {
// Symbol bodies are a relatively unstructured format that contains a series
// of properly nested punctuation, with anything else in the middle. Scan
// ahead to find it and consume it if successful, otherwise emit an error.
@@ -65,7 +63,16 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
// go until we find the matching '>' character.
assert(*curPtr == '<');
SmallVector<char, 8> nestedPunctuation;
+ const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
do {
+ // Handle code completions, which may appear in the middle of the symbol
+ // body.
+ if (curPtr == codeCompleteLoc) {
+ isCodeCompletion = true;
+ nestedPunctuation.clear();
+ break;
+ }
+
char c = *curPtr++;
switch (c) {
case '\0':
@@ -107,9 +114,19 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
case '"': {
// Dispatch to the lexer to lex past strings.
resetToken(curPtr - 1);
+ curPtr = state.curToken.getEndLoc().getPointer();
+
+ // Handle code completions, which may appear in the middle of the symbol
+ // body.
+ if (state.curToken.isCodeCompletion()) {
+ isCodeCompletion = true;
+ nestedPunctuation.clear();
+ break;
+ }
+
+ // Otherwise, ensure this token was actually a string.
if (state.curToken.isNot(Token::string))
return failure();
- curPtr = state.curToken.getEndLoc().getPointer();
break;
}
@@ -129,19 +146,24 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
/// Parse an extended dialect symbol.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
-static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
- SymbolAliasMap &aliases,
+static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
CreateFn &&createSymbol) {
+ Token tok = p.getToken();
+
+ // Handle code completion of the extended symbol.
+ StringRef identifier = tok.getSpelling().drop_front();
+ if (tok.isCodeCompletion() && identifier.empty())
+ return p.codeCompleteDialectSymbol(aliases);
+
// Parse the dialect namespace.
- StringRef identifier = p.getTokenSpelling().drop_front();
SMLoc loc = p.getToken().getLoc();
- p.consumeToken(identifierTok);
+ p.consumeToken();
// Check to see if this is a pretty name.
StringRef dialectName;
StringRef symbolData;
std::tie(dialectName, symbolData) = identifier.split('.');
- bool isPrettyName = !symbolData.empty();
+ bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
// Check to see if the symbol has trailing data, i.e. has an immediately
// following '<'.
@@ -167,9 +189,17 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
if (!isPrettyName) {
// Point the symbol data to the end of the dialect name to start.
symbolData = StringRef(dialectName.end(), 0);
- if (p.parseDialectSymbolBody(symbolData))
+
+ // Parse the body of the symbol.
+ bool isCodeCompletion = false;
+ if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
return nullptr;
- symbolData = symbolData.drop_front().drop_back();
+ symbolData = symbolData.drop_front();
+
+ // If the body contained a code completion it won't have the trailing `>`
+ // token, so don't drop it.
+ if (!isCodeCompletion)
+ symbolData = symbolData.drop_back();
} else {
loc = SMLoc::getFromPointer(symbolData.data());
@@ -192,7 +222,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
Attribute Parser::parseExtendedAttr(Type type) {
MLIRContext *ctx = getContext();
Attribute attr = parseExtendedSymbol<Attribute>(
- *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
+ *this, state.symbols.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
Type attrType = type;
@@ -238,7 +268,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
Type Parser::parseExtendedType() {
MLIRContext *ctx = getContext();
return parseExtendedSymbol<Type>(
- *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
+ *this, state.symbols.typeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h
index e09ae168e3bba..10ef3bf6429b4 100644
--- a/mlir/lib/Parser/Lexer.h
+++ b/mlir/lib/Parser/Lexer.h
@@ -40,6 +40,10 @@ class Lexer {
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
+ /// Return the code completion location of the lexer, or nullptr if there is
+ /// none.
+ const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
+
private:
// Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index be4f886a9d0fb..e27f72c221714 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -404,6 +404,26 @@ ParseResult Parser::codeCompleteOptionalTokens(ArrayRef<StringRef> tokens) {
return failure();
}
+Attribute Parser::codeCompleteAttribute() {
+ state.codeCompleteContext->completeAttribute(
+ state.symbols.attributeAliasDefinitions);
+ return {};
+}
+Type Parser::codeCompleteType() {
+ state.codeCompleteContext->completeType(state.symbols.typeAliasDefinitions);
+ return {};
+}
+
+Attribute
+Parser::codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases) {
+ state.codeCompleteContext->completeDialectAttributeOrAlias(aliases);
+ return {};
+}
+Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
+ state.codeCompleteContext->completeDialectTypeOrAlias(aliases);
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 34a95a550e9d1..ce0d1b3b1b714 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -57,7 +57,16 @@ class Parser {
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
- ParseResult parseDialectSymbolBody(StringRef &body);
+ /// Parse the body of a dialect symbol, which starts and ends with <>'s, and
+ /// may be recursive. Return with the 'body' StringRef encompassing the entire
+ /// body. `isCodeCompletion` is set to true if the body contained a code
+ /// completion location, in which case the body is only populated up to the
+ /// completion.
+ ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion);
+ ParseResult parseDialectSymbolBody(StringRef &body) {
+ bool isCodeCompletion = false;
+ return parseDialectSymbolBody(body, isCodeCompletion);
+ }
// We have two forms of parsing methods - those that return a non-null
// pointer on success, and those that return a ParseResult to indicate whether
@@ -322,6 +331,12 @@ class Parser {
ParseResult codeCompleteExpectedTokens(ArrayRef<StringRef> tokens);
ParseResult codeCompleteOptionalTokens(ArrayRef<StringRef> tokens);
+ Attribute codeCompleteAttribute();
+ Type codeCompleteType();
+ Attribute
+ codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases);
+ Type codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases);
+
protected:
/// The Parser is subclassed and reinstantiated. Do not add additional
/// non-trivial state here, add it to the ParserState class.
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 660dfe1cbea28..a916ec72e5f4d 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -358,6 +358,12 @@ Type Parser::parseNonFunctionType() {
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
+
+ // Handle completion of a dialect type.
+ case Token::code_complete:
+ if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
+ return parseExtendedType();
+ return codeCompleteType();
}
}
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index 03fe3ca6e41e7..0fc30d5bcced2 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -781,8 +781,9 @@ enum class InsertTextFormat {
struct CompletionItem {
CompletionItem() = default;
- CompletionItem(StringRef label, CompletionItemKind kind)
- : label(label.str()), kind(kind),
+ CompletionItem(const Twine &label, CompletionItemKind kind,
+ StringRef sortText = "")
+ : label(label.str()), kind(kind), sortText(sortText.str()),
insertTextFormat(InsertTextFormat::PlainText) {}
/// The label of this completion item. By default also the text that is
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 9d3d51b48b2e1..a2f4f6f8d36d1 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -636,15 +636,17 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
: AsmParserCodeCompleteContext(completeLoc),
completionList(completionList), ctx(ctx) {}
- /// Signal code completion for a dialect name.
- void completeDialectName() final {
+ /// Signal code completion for a dialect name, with an optional prefix.
+ void completeDialectName(StringRef prefix) final {
for (StringRef dialect : ctx->getAvailableDialects()) {
- lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
- item.sortText = "2";
+ lsp::CompletionItem item(prefix + dialect,
+ lsp::CompletionItemKind::Module,
+ /*sortText=*/"3");
item.detail = "dialect";
completionList.items.emplace_back(item);
}
}
+ using AsmParserCodeCompleteContext::completeDialectName;
/// Signal code completion for an operation name within the given dialect.
void completeOperationName(StringRef dialectName) final {
@@ -658,8 +660,8 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
lsp::CompletionItem item(
op.getStringRef().drop_front(dialectName.size() + 1),
- lsp::CompletionItemKind::Field);
- item.sortText = "1";
+ lsp::CompletionItemKind::Field,
+ /*sortText=*/"1");
item.detail = "operation";
completionList.items.emplace_back(item);
}
@@ -693,13 +695,71 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
/// Signal a completion for the given expected token.
void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
for (StringRef token : tokens) {
- lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword);
- item.sortText = "0";
+ lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
+ /*sortText=*/"0");
item.detail = optional ? "optional" : "";
completionList.items.emplace_back(item);
}
}
+ /// Signal a completion for an attribute.
+ void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
+ appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
+ "loc", "opaque", "sparse", "true", "unit"},
+ lsp::CompletionItemKind::Field,
+ /*sortText=*/"1");
+
+ completeDialectName("#");
+ completeAliases(aliases, "#");
+ }
+ void completeDialectAttributeOrAlias(
+ const llvm::StringMap<Attribute> &aliases) override {
+ completeDialectName();
+ completeAliases(aliases);
+ }
+
+ /// Signal a completion for a type.
+ void completeType(const llvm::StringMap<Type> &aliases) override {
+ appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
+ "bf16", "f16", "f32", "f64", "f80", "f128",
+ "index", "none"},
+ lsp::CompletionItemKind::Field,
+ /*sortText=*/"1");
+ lsp::CompletionItem item("i<N>", lsp::CompletionItemKind::Field,
+ /*sortText=*/"1");
+ item.insertText = "i";
+ completionList.items.emplace_back(item);
+
+ completeDialectName("!");
+ completeAliases(aliases, "!");
+ }
+ void
+ completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
+ completeDialectName();
+ completeAliases(aliases);
+ }
+
+ /// Add completion results for the given set of aliases.
+ template <typename T>
+ void completeAliases(const llvm::StringMap<T> &aliases,
+ StringRef prefix = "") {
+ for (const auto &alias : aliases) {
+ lsp::CompletionItem item(prefix + alias.getKey(),
+ lsp::CompletionItemKind::Field,
+ /*sortText=*/"2");
+ llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
+ completionList.items.emplace_back(item);
+ }
+ }
+
+ /// Add a set of simple completions that all have the same kind.
+ void appendSimpleCompletions(ArrayRef<StringRef> completions,
+ lsp::CompletionItemKind kind,
+ StringRef sortText = "") {
+ for (StringRef completion : completions)
+ completionList.items.emplace_back(completion, kind, sortText);
+ }
+
private:
lsp::CompletionList &completionList;
MLIRContext *ctx;
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 5418025d501c9..df56055294385 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -408,12 +408,9 @@ void TestDialect::registerTypes() {
Type TestDialect::parseTestType(AsmParser &parser,
SetVector<Type> &stack) const {
StringRef typeTag;
- if (failed(parser.parseKeyword(&typeTag)))
- return Type();
-
{
Type genType;
- auto parseResult = generatedTypeParser(parser, typeTag, genType);
+ auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
}
diff --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test
index 4264fe69bc737..fe067b256f7ac 100644
--- a/mlir/test/mlir-lsp-server/completion.test
+++ b/mlir/test/mlir-lsp-server/completion.test
@@ -5,14 +5,14 @@
"uri":"test:///foo.mlir",
"languageId":"mlir",
"version":1,
- "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %"
+ "text":"#attr = i32\n!alias = i32\nfunc.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (!pdl.value)\nreturn %"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
- "position":{"line":0,"character":0}
+ "position":{"line":2,"character":0}
}}
-// CHECK: "id": 1
+// CHECK-LABEL: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@@ -22,7 +22,7 @@
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 9,
// CHECK: "label": "builtin",
-// CHECK: "sortText": "2"
+// CHECK: "sortText": "3"
// CHECK: },
// CHECK: {
// CHECK: "detail": "operation",
@@ -34,11 +34,11 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":2,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
- "position":{"line":1,"character":9}
+ "position":{"line":3,"character":9}
}}
-// CHECK: "id": 1
+// CHECK-LABEL: "id": 2
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@@ -48,17 +48,17 @@
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 9,
// CHECK: "label": "builtin",
-// CHECK: "sortText": "2"
+// CHECK: "sortText": "3"
// CHECK: },
// CHECK-NOT: "detail": "operation",
// CHECK: ]
// CHECK-NEXT: }
// -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":3,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
- "position":{"line":1,"character":17}
+ "position":{"line":3,"character":17}
}}
-// CHECK: "id": 1
+// CHECK-LABEL: "id": 3
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@@ -74,17 +74,17 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":4,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
- "position":{"line":2,"character":8}
+ "position":{"line":4,"character":8}
}}
-// CHECK: "id": 1
+// CHECK-LABEL: "id": 4
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
-// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: i32",
+// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: !pdl.value",
// CHECK-NEXT: "insertText": "cast",
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 6,
@@ -100,11 +100,11 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
-{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+{"jsonrpc":"2.0","id":5,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
- "position":{"line":0,"character":10}
+ "position":{"line":2,"character":10}
}}
-// CHECK: "id": 1
+// CHECK-LABEL: "id": 5
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@@ -133,6 +133,134 @@
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
-{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+{"jsonrpc":"2.0","id":6,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":0,"character":8}
+}}
+// CHECK-LABEL: "id": 6
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "false"
+// CHECK: },
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "loc"
+// CHECK: },
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "true"
+// CHECK: },
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "unit"
+// CHECK: }
+// CHECK: ]
+// CHECK: }
+// -----
+{"jsonrpc":"2.0","id":7,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":3,"character":56}
+}}
+// CHECK-LABEL: "id": 7
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "index"
+// CHECK: },
+// CHECK: {
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "none"
+// CHECK: },
+// CHECK: {
+// CHECK: "insertText": "i",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "i<N>"
+// CHECK: }
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":8,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":3,"character":57}
+}}
+// CHECK-LABEL: "id": 8
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK: {
+// CHECK: "detail": "dialect",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 9,
+// CHECK: "label": "builtin",
+// CHECK: "sortText": "3"
+// CHECK: },
+// CHECK: {
+// CHECK: "detail": "alias: i32",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "alias",
+// CHECK: "sortText": "2"
+// CHECK: }
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":9,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":3,"character":61}
+}}
+// CHECK-LABEL: "id": 9
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 14,
+// CHECK-NEXT: "label": "attribute",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 14,
+// CHECK-NEXT: "label": "operation",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 14,
+// CHECK-NEXT: "label": "range",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 14,
+// CHECK-NEXT: "label": "type",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 14,
+// CHECK-NEXT: "label": "value",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: }
+// CHECK-NEXT: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":10,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 3737ce77948e2..50dba45c63d0b 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -21,16 +21,19 @@ include "mlir/IR/OpBase.td"
// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
-// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type,
+// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,
// DEF-SAME: ::mlir::Attribute &value) {
-// DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) {
+// DEF: return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
+// DEF: .Case(::test::CompoundAAttr::getMnemonic()
// DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
-// DEF-NEXT: }
-// DEF-NEXT: if (mnemonic == ::test::IndexAttr::getMnemonic()) {
+// DEF-NEXT: })
+// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
// DEF-NEXT: value = ::test::IndexAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
-// DEF: return {};
+// DEF: .Default([&](llvm::StringRef keyword,
+// DEF-NEXT: *mnemonic = keyword;
+// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
diff --git a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
index ac898a97d9eee..1c479f11abbf6 100644
--- a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
+++ b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
@@ -27,11 +27,9 @@ def AttrA : TestAttr<"AttrA"> {
// ATTR: ::mlir::Type type) const {
// ATTR: ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
// ATTR: ::llvm::StringRef attrTag;
-// ATTR: if (::mlir::failed(parser.parseKeyword(&attrTag)))
-// ATTR: return {};
// ATTR: {
// ATTR: ::mlir::Attribute attr;
-// ATTR: auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
+// ATTR: auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
// ATTR: if (parseResult.hasValue())
// ATTR: return attr;
// ATTR: }
@@ -57,10 +55,8 @@ def TypeA : TestType<"TypeA"> {
// TYPE: ::mlir::Type TestDialect::parseType(::mlir::DialectAsmParser &parser) const {
// TYPE: ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
// TYPE: ::llvm::StringRef mnemonic;
-// TYPE: if (parser.parseKeyword(&mnemonic))
-// TYPE: return ::mlir::Type();
// TYPE: ::mlir::Type genType;
-// TYPE: auto parseResult = generatedTypeParser(parser, mnemonic, genType);
+// TYPE: auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
// TYPE: if (parseResult.hasValue())
// TYPE: return genType;
// TYPE: parser.emitError(typeLoc) << "unknown type `"
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 833195530f9b0..29ebad5439c08 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -22,16 +22,18 @@ include "mlir/IR/OpBase.td"
// DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
-// DEF-SAME: ::llvm::StringRef mnemonic,
+// DEF-SAME: ::llvm::StringRef *mnemonic,
// DEF-SAME: ::mlir::Type &value) {
-// DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) {
+// DEF: .Case(::test::CompoundAType::getMnemonic()
// DEF-NEXT: value = ::test::CompoundAType::parse(parser);
// DEF-NEXT: return ::mlir::success(!!value);
-// DEF-NEXT: }
-// DEF-NEXT: if (mnemonic == ::test::IndexType::getMnemonic()) {
+// DEF-NEXT: })
+// DEF-NEXT: .Case(::test::IndexType::getMnemonic()
// DEF-NEXT: value = ::test::IndexType::parse(parser);
// DEF-NEXT: return ::mlir::success(!!value);
-// DEF: return {};
+// DEF: .Default([&](llvm::StringRef keyword,
+// DEF-NEXT: *mnemonic = keyword;
+// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 759143dc606b0..db443e9ac0cff 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -673,11 +673,9 @@ ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef attrTag;
- if (::mlir::failed(parser.parseKeyword(&attrTag)))
- return {{};
{{
::mlir::Attribute attr;
- auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
+ auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
if (parseResult.hasValue())
return attr;
}
@@ -723,10 +721,8 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"(
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef mnemonic;
- if (parser.parseKeyword(&mnemonic))
- return ::mlir::Type();
::mlir::Type genType;
- auto parseResult = generatedTypeParser(parser, mnemonic, genType);
+ auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.hasValue())
return genType;
{1}
@@ -771,7 +767,7 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
}
// Declare the parser.
SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
- {"::llvm::StringRef", "mnemonic"}};
+ {"::llvm::StringRef *", "mnemonic"}};
if (isAttrGenerator)
params.emplace_back("::mlir::Type", "type");
params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
@@ -784,14 +780,18 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
{{strfmt("::mlir::{0}", valueType), "def"},
{"::mlir::AsmPrinter &", "printer"}});
- // The parser dispatch is just a list of if-elses, matching on the mnemonic
- // and calling the def's parse function.
+ // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
+ // calling the def's parse function.
+ parse.body() << " return "
+ "::mlir::AsmParser::KeywordSwitch<::mlir::"
+ "OptionalParseResult>(parser)\n";
const char *const getValueForMnemonic =
- R"( if (mnemonic == {0}::getMnemonic()) {{
- value = {0}::{1};
- return ::mlir::success(!!value);
- }
+ R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
+ value = {0}::{1};
+ return ::mlir::success(!!value);
+ })
)";
+
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
@@ -822,7 +822,10 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
printDef = "\nt.print(printer);";
printer.body() << llvm::formatv(printValue, defClass, printDef);
}
- parse.body() << " return {};";
+ parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
+ " *mnemonic = keyword;\n"
+ " return llvm::None;\n"
+ " });";
printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
raw_indented_ostream indentedOs(os);
More information about the Mlir-commits
mailing list