[Mlir-commits] [mlir] ed2fb17 - [mlir:LSP] Add support for MLIR code completions
River Riddle
llvmlistbot at llvm.org
Thu Jul 7 13:36:30 PDT 2022
Author: River Riddle
Date: 2022-07-07T13:35:54-07:00
New Revision: ed2fb1736ac1515838d136b57e5ad9a5c805001b
URL: https://github.com/llvm/llvm-project/commit/ed2fb1736ac1515838d136b57e5ad9a5c805001b
DIFF: https://github.com/llvm/llvm-project/commit/ed2fb1736ac1515838d136b57e5ad9a5c805001b.diff
LOG: [mlir:LSP] Add support for MLIR code completions
This commit adds code completion results to the MLIR LSP using
a new code completion context in the MLIR parser. This commit
adds initial completion for dialect, operation, SSA value, and
block names.
Differential Revision: https://reviews.llvm.org/D129183
Added:
mlir/include/mlir/Parser/CodeComplete.h
mlir/test/mlir-lsp-server/completion.test
Modified:
mlir/include/mlir/Parser/Parser.h
mlir/lib/Parser/AffineParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Lexer.cpp
mlir/lib/Parser/Lexer.h
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/ParserState.h
mlir/lib/Parser/Token.cpp
mlir/lib/Parser/Token.h
mlir/lib/Parser/TokenKinds.def
mlir/lib/Tools/lsp-server-support/Protocol.h
mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
mlir/test/mlir-lsp-server/initialize-params.test
Removed:
################################################################################
diff --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h
new file mode 100644
index 0000000000000..24d0181456f67
--- /dev/null
+++ b/mlir/include/mlir/Parser/CodeComplete.h
@@ -0,0 +1,58 @@
+//===- CodeComplete.h - MLIR Asm CodeComplete Context -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PARSER_CODECOMPLETE_H
+#define MLIR_PARSER_CODECOMPLETE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+/// Abstract hook through which the asm parser reports code completion
+/// requests. It exists to support language tooling (e.g. an LSP server);
+/// ordinary clients of the parser have no reason to implement it.
+class AsmParserCodeCompleteContext {
+public:
+  virtual ~AsmParserCodeCompleteContext();
+
+  /// Returns the position in the source buffer at which completion results
+  /// are being requested.
+  SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; }
+
+  //===--------------------------------------------------------------------===//
+  // Completion Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Invoked when a dialect name is expected at the completion point.
+  virtual void completeDialectName() = 0;
+
+  /// Invoked when an operation name belonging to `dialectName` is expected at
+  /// the completion point.
+  virtual void completeOperationName(StringRef dialectName) = 0;
+
+  /// Offers the SSA value `name` as a candidate for an SSA value completion;
+  /// `typeData` is a descriptive detail string for that value.
+  virtual void appendSSAValueCompletion(StringRef name,
+                                        std::string typeData) = 0;
+
+  /// Offers the block `name` as a candidate for a block name completion.
+  virtual void appendBlockCompletion(StringRef name) = 0;
+
+protected:
+  /// Only subclasses may construct a context; `loc` is the position at which
+  /// completion is requested.
+  explicit AsmParserCodeCompleteContext(SMLoc loc) : codeCompleteLoc(loc) {}
+
+private:
+  /// Source position at which completion is requested.
+  SMLoc codeCompleteLoc;
+};
+} // namespace mlir
+
+#endif // MLIR_PARSER_CODECOMPLETE_H
diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index 69f02c45d6fbb..5c9543f187d6a 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -26,6 +26,7 @@ class StringRef;
namespace mlir {
class AsmParserState;
+class AsmParserCodeCompleteContext;
namespace detail {
@@ -83,11 +84,13 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
/// source file that is being parsed. If `asmState` is non-null, it is populated
/// with detailed information about the parsed IR (including exact locations for
/// SSA uses and definitions). `asmState` should only be provided if this
-/// detailed information is desired.
-LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
- const ParserConfig &config,
- LocationAttr *sourceFileLoc = nullptr,
- AsmParserState *asmState = nullptr);
+/// detailed information is desired. If `codeCompleteContext` is non-null, it is
+/// used to signal tracking of a code completion event (generally only ever
+/// useful for LSP or other high level language tooling).
+LogicalResult parseSourceFile(
+ const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config,
+ LocationAttr *sourceFileLoc = nullptr, AsmParserState *asmState = nullptr,
+ AsmParserCodeCompleteContext *codeCompleteContext = nullptr);
/// This parses the file specified by the indicated filename and appends parsed
/// operations to the given block. If the block is non-empty, the operations are
diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
index 7fa054bf3913a..345b3a505b069 100644
--- a/mlir/lib/Parser/AffineParser.cpp
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -743,7 +743,8 @@ IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context,
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState symbolState;
ParserConfig config(context);
- ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr);
+ ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr,
+ /*codeCompleteContext=*/nullptr);
Parser parser(state);
raw_ostream &os = printDiagnosticInfo ? llvm::errs() : llvm::nulls();
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 623ea8ef12ea1..4478d1936e661 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -277,7 +277,8 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState aliasState;
ParserConfig config(context);
- ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr);
+ ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
+ /*codeCompleteContext=*/nullptr);
Parser parser(state);
SourceMgrDiagnosticHandler handler(
diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp
index f3e4ce6b2768e..cbebaf24159d4 100644
--- a/mlir/lib/Parser/Lexer.cpp
+++ b/mlir/lib/Parser/Lexer.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
@@ -26,11 +27,16 @@ static bool isPunct(char c) {
return c == '$' || c == '.' || c == '_' || c == '-';
}
-Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context)
- : sourceMgr(sourceMgr), context(context) {
+/// Construct a lexer over the main buffer of `sourceMgr`. When a code
+/// completion context is supplied, its completion position is remembered so
+/// the lexer can emit a `code_complete` token there; otherwise the position is
+/// left null and no completion token is produced.
+Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
+             AsmParserCodeCompleteContext *codeCompleteContext)
+    : sourceMgr(sourceMgr), context(context),
+      codeCompleteLoc(codeCompleteContext
+                          ? codeCompleteContext->getCodeCompleteLoc().getPointer()
+                          : nullptr) {
auto bufferID = sourceMgr.getMainFileID();
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
}
/// Encode the specified source location information into an attribute for
@@ -61,6 +67,12 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
Token Lexer::lexToken() {
while (true) {
const char *tokStart = curPtr;
+
+ // Check to see if the current token is at the code completion location.
+ if (tokStart == codeCompleteLoc)
+ return formToken(Token::code_complete, tokStart);
+
+ // Lex the next token.
switch (*curPtr++) {
default:
// Handle bare identifiers.
@@ -357,17 +369,25 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
// Parse suffix-id.
if (isdigit(*curPtr)) {
// If suffix-id starts with a digit, the rest must be digits.
- while (isdigit(*curPtr)) {
+ while (isdigit(*curPtr))
++curPtr;
- }
} else if (isalpha(*curPtr) || isPunct(*curPtr)) {
do {
++curPtr;
} while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
+ } else if (curPtr == codeCompleteLoc) {
+ return formToken(Token::code_complete, tokStart);
} else {
return emitError(curPtr - 1, errorKind);
}
+ // Check for a code completion within the identifier.
+ if (codeCompleteLoc && codeCompleteLoc >= tokStart &&
+ codeCompleteLoc <= curPtr) {
+ return Token(Token::code_complete,
+ StringRef(tokStart, codeCompleteLoc - tokStart));
+ }
+
return formToken(kind, tokStart);
}
@@ -380,6 +400,13 @@ Token Lexer::lexString(const char *tokStart) {
assert(curPtr[-1] == '"');
while (true) {
+ // Check to see if there is a code completion location within the string. In
+ // these cases we generate a completion location and place the currently
+ // lexed string within the token. This allows for the parser to use the
+ // partially lexed string when computing the completion results.
+ if (curPtr == codeCompleteLoc)
+ return formToken(Token::code_complete, tokStart);
+
switch (*curPtr++) {
case '"':
return formToken(Token::string, tokStart);
diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h
index 58b5e2321c4d3..e09ae168e3bba 100644
--- a/mlir/lib/Parser/Lexer.h
+++ b/mlir/lib/Parser/Lexer.h
@@ -22,7 +22,8 @@ class Location;
/// This class breaks up the current file into a token stream.
class Lexer {
public:
- explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+ explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
+ AsmParserCodeCompleteContext *codeCompleteContext);
const llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
@@ -64,6 +65,10 @@ class Lexer {
StringRef curBuffer;
const char *curPtr;
+ /// An optional code completion point within the input file, used to indicate
+ /// the position of a code completion token.
+ const char *codeCompleteLoc;
+
Lexer(const Lexer &) = delete;
void operator=(const Lexer &) = delete;
};
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 5bc1e43c60a0e..51a65a99dfa7a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/AsmParserState.h"
+#include "mlir/Parser/CodeComplete.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
@@ -32,6 +33,12 @@ using namespace mlir::detail;
using llvm::MemoryBuffer;
using llvm::SourceMgr;
+//===----------------------------------------------------------------------===//
+// CodeComplete
+//===----------------------------------------------------------------------===//
+
+// Out-of-line definition of the virtual destructor declared in
+// mlir/Parser/CodeComplete.h. NOTE(review): presumably kept out of the header
+// to anchor the class's vtable in this translation unit.
+AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
+
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
@@ -334,6 +341,60 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
return entry.second;
}
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+/// Ask the completion client for dialect name results. Always returns
+/// `failure` so that parsing stops at the completion point.
+ParseResult Parser::codeCompleteDialectName() {
+  AsmParserCodeCompleteContext *completer = state.codeCompleteContext;
+  completer->completeDialectName();
+  return failure();
+}
+
+/// Ask the completion client for operation names within `dialectName`.
+/// Always returns `failure` so that parsing stops at the completion point.
+ParseResult Parser::codeCompleteOperationName(StringRef dialectName) {
+  // A cheap plausibility check on the dialect namespace. It is not meant to be
+  // exhaustive; it only skips querying for results that cannot match.
+  bool plausibleDialect = !dialectName.empty() && !dialectName.contains('.');
+  if (plausibleDialect)
+    state.codeCompleteContext->completeOperationName(dialectName);
+  return failure();
+}
+
+/// Handle a completion request at `loc` where an operation is expected but no
+/// name has been typed yet. Results are offered both for dialect names and for
+/// operation names in the current default dialect (whose prefix may be
+/// elided). Always returns `failure` so that parsing stops.
+ParseResult Parser::codeCompleteDialectOrElidedOpName(SMLoc loc) {
+  // Check to see if there is anything else on the current line before the
+  // completion point. This check isn't strictly necessary, but it does avoid
+  // unnecessarily triggering completions for operations and dialects in
+  // situations where we don't want them (e.g. at the end of an operation).
+  auto shouldIgnoreOpCompletion = [&]() {
+    const char *bufBegin = state.lex.getBufferBegin();
+    // Walk backwards over the characters preceding `loc` until the start of
+    // the line or of the buffer. Inspecting `it[-1]` ensures the very first
+    // character of the buffer is checked too, and never forms a pointer
+    // before the start of the buffer when `loc` is at the beginning.
+    for (const char *it = loc.getPointer(); it > bufBegin && it[-1] != '\n';
+         --it)
+      if (!llvm::is_contained(StringRef(" \t\r"), it[-1]))
+        return true;
+    return false;
+  };
+  if (shouldIgnoreOpCompletion())
+    return failure();
+
+  // The completion here is either for a dialect name, or an operation name
+  // whose dialect prefix was elided. For this we simply invoke both of the
+  // individual completion methods.
+  (void)codeCompleteDialectName();
+  return codeCompleteOperationName(state.defaultDialectStack.back());
+}
+
+/// Handle a completion request inside a quoted (generic-form) operation name,
+/// where `name` is the text typed so far after the opening quote. Always
+/// returns `failure` so that parsing stops.
+ParseResult Parser::codeCompleteStringDialectOrOperationName(StringRef name) {
+  // Nothing typed yet after the quote: the dialect namespace comes first.
+  if (name.empty())
+    return codeCompleteDialectName();
+
+  // Only a trailing `.` identifies the typed text as a complete dialect
+  // namespace whose operations can be offered.
+  if (name.back() != '.')
+    return failure();
+  return codeCompleteOperationName(name.drop_back());
+}
+
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
@@ -497,6 +558,17 @@ class OperationParser : public Parser {
/// us to diagnose references to blocks that are not defined precisely.
Block *getBlockNamed(StringRef name, SMLoc loc);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ /// The set of various code completion methods. Every completion method
+ /// returns `failure` to stop the parsing process after providing completion
+ /// results.
+
+ ParseResult codeCompleteSSAUse();
+ ParseResult codeCompleteBlock();
+
private:
/// This class represents a definition of a Block.
struct BlockDefinition {
@@ -790,7 +862,7 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
///
ParseResult OperationParser::parseOptionalSSAUseList(
SmallVectorImpl<UnresolvedOperand> &results) {
- if (getToken().isNot(Token::percent_identifier))
+ if (!getToken().isOrIsCodeCompletionFor(Token::percent_identifier))
return success();
return parseCommaSeparatedList([&]() -> ParseResult {
UnresolvedOperand result;
@@ -807,6 +879,9 @@ ParseResult OperationParser::parseOptionalSSAUseList(
///
ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result,
bool allowResultNumber) {
+ if (getToken().isCodeCompletion())
+ return codeCompleteSSAUse();
+
result.name = getTokenSpelling();
result.number = 0;
result.location = getToken().getLoc();
@@ -1017,6 +1092,10 @@ ParseResult OperationParser::parseOperation() {
op = parseCustomOperation(resultIDs);
else if (nameTok.is(Token::string))
op = parseGenericOperation();
+ else if (nameTok.isCodeCompletionFor(Token::string))
+ return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
+ else if (nameTok.isCodeCompletion())
+ return codeCompleteDialectOrElidedOpName(loc);
else
return emitWrongTokenError("expected operation name in quotes");
@@ -1071,6 +1150,9 @@ ParseResult OperationParser::parseOperation() {
/// successor ::= block-id
///
ParseResult OperationParser::parseSuccessor(Block *&dest) {
+ if (getToken().isCodeCompletion())
+ return codeCompleteBlock();
+
// Verify branch is identifier and get the matching block.
if (!getToken().is(Token::caret_identifier))
return emitWrongTokenError("expected block name");
@@ -1391,7 +1473,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
OptionalParseResult
parseOptionalOperand(UnresolvedOperand &result,
bool allowResultNumber = true) override {
- if (parser.getToken().is(Token::percent_identifier))
+ if (parser.getToken().isOrIsCodeCompletionFor(Token::percent_identifier))
return parseOperand(result, allowResultNumber);
return llvm::None;
}
@@ -1406,14 +1488,15 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
if (delimiter == Delimiter::None) {
// parseCommaSeparatedList doesn't handle the missing case for "none",
// so we handle it custom here.
- if (parser.getToken().isNot(Token::percent_identifier)) {
+ Token tok = parser.getToken();
+ if (!tok.isOrIsCodeCompletionFor(Token::percent_identifier)) {
// If we didn't require any operands or required exactly zero (weird)
// then this is success.
if (requiredOperandCount == -1 || requiredOperandCount == 0)
return success();
// Otherwise, try to produce a nice error message.
- if (parser.getToken().isAny(Token::l_paren, Token::l_square))
+ if (tok.isAny(Token::l_paren, Token::l_square))
return parser.emitError("unexpected delimiter");
return parser.emitWrongTokenError("expected operand");
}
@@ -1597,7 +1680,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// Parse an optional operation successor and its operand list.
OptionalParseResult parseOptionalSuccessor(Block *&dest) override {
- if (parser.getToken().isNot(Token::caret_identifier))
+ if (!parser.getToken().isOrIsCodeCompletionFor(Token::caret_identifier))
return llvm::None;
return parseSuccessor(dest);
}
@@ -1681,7 +1764,8 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
} // namespace
FailureOr<OperationName> OperationParser::parseCustomOperationName() {
- std::string opName = getTokenSpelling().str();
+ Token nameTok = getToken();
+ StringRef opName = nameTok.getSpelling();
if (opName.empty())
return (emitError("empty operation name is invalid"), failure());
consumeToken();
@@ -1694,11 +1778,17 @@ FailureOr<OperationName> OperationParser::parseCustomOperationName() {
// If the operation doesn't have a dialect prefix try using the default
// dialect.
- auto opNameSplit = StringRef(opName).split('.');
+ auto opNameSplit = opName.split('.');
StringRef dialectName = opNameSplit.first;
+ std::string opNameStorage;
if (opNameSplit.second.empty()) {
+ // If the name didn't have a prefix, check for a code completion request.
+ if (getToken().isCodeCompletion() && opName.back() == '.')
+ return codeCompleteOperationName(dialectName);
+
dialectName = getState().defaultDialectStack.back();
- opName = (dialectName + "." + opName).str();
+ opNameStorage = (dialectName + "." + opName).str();
+ opName = opNameStorage;
}
// Try to load the dialect before returning the operation name to make sure
@@ -2091,6 +2181,58 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
});
}
+//===----------------------------------------------------------------------===//
+// Code Completion
+//===----------------------------------------------------------------------===//
+
+/// Offer every named SSA value from the isolated name scopes as a completion
+/// result. Each result carries a detail string: the defining op name (for
+/// non-placeholder results) or argument number, followed by the type. Always
+/// returns `failure` so that parsing stops.
+ParseResult OperationParser::codeCompleteSSAUse() {
+  for (IsolatedSSANameScope &scope : isolatedNameScopes) {
+    for (auto &it : scope.values) {
+      if (it.second.empty())
+        continue;
+      Value frontValue = it.second.front().value;
+
+      // Use a fresh buffer for every value. Reusing one buffer and relying on
+      // `std::move` to empty it is not guaranteed by the standard (a
+      // moved-from string is only valid-but-unspecified) and could leak one
+      // value's detail text into the next.
+      std::string detailData;
+      llvm::raw_string_ostream detailOS(detailData);
+
+      // If the value isn't a forward reference, we also add the name of the op
+      // to the detail.
+      if (auto result = frontValue.dyn_cast<OpResult>()) {
+        if (!forwardRefPlaceholders.count(result))
+          detailOS << result.getOwner()->getName() << ": ";
+      } else {
+        detailOS << "arg #" << frontValue.cast<BlockArgument>().getArgNumber()
+                 << ": ";
+      }
+
+      // Emit the type of the values to aid with completion selection.
+      detailOS << frontValue.getType();
+
+      // FIXME: We should define a policy for packed values, e.g. with a limit
+      // on the detail size, but it isn't clear what would be useful right now.
+      // For now we just only emit the first type.
+      if (it.second.size() > 1)
+        detailOS << ", ...";
+
+      // `detailData` is local to this iteration, so moving out of it is safe.
+      state.codeCompleteContext->appendSSAValueCompletion(
+          it.getKey(), std::move(detailOS.str()));
+    }
+  }
+
+  return failure();
+}
+
+/// Offer the names of the blocks in the innermost region scope as completion
+/// results. Always returns `failure` so that parsing stops.
+ParseResult OperationParser::codeCompleteBlock() {
+  // Only complete when nothing beyond an optional `^` has been typed; this
+  // avoids odd results when, e.g., a `.` appears within the identifier.
+  StringRef typed = getTokenSpelling();
+  bool onlyPrefixTyped = typed.empty() || typed == "^";
+  if (onlyPrefixTyped) {
+    for (const auto &entry : blocksByName.back())
+      state.codeCompleteContext->appendBlockCompletion(entry.getFirst());
+  }
+  return failure();
+}
+
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//
@@ -2418,10 +2560,11 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
//===----------------------------------------------------------------------===//
-LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
- Block *block, const ParserConfig &config,
- LocationAttr *sourceFileLoc,
- AsmParserState *asmState) {
+LogicalResult
+mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
+ const ParserConfig &config, LocationAttr *sourceFileLoc,
+ AsmParserState *asmState,
+ AsmParserCodeCompleteContext *codeCompleteContext) {
const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
Location parserLoc =
@@ -2431,7 +2574,8 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
*sourceFileLoc = parserLoc;
SymbolState aliasState;
- ParserState state(sourceMgr, config, aliasState, asmState);
+ ParserState state(sourceMgr, config, aliasState, asmState,
+ codeCompleteContext);
return TopLevelOperationParser(state).parse(block, parserLoc);
}
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index e762fc6d3543b..f1c87ca4916b5 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -306,6 +306,20 @@ class Parser {
parseAffineExprOfSSAIds(AffineExpr &expr,
function_ref<ParseResult(bool)> parseElement);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ /// The set of various code completion methods. Every completion method
+ /// returns `failure` to signal that parsing should abort after any desired
+/// completions have been enqueued. Note that `failure` does not mean
+ /// completion failed, it's just a signal to the parser to stop.
+
+ ParseResult codeCompleteDialectName();
+ ParseResult codeCompleteOperationName(StringRef dialectName);
+ ParseResult codeCompleteDialectOrElidedOpName(SMLoc loc);
+ ParseResult codeCompleteStringDialectOrOperationName(StringRef name);
+
protected:
/// The Parser is subclassed and reinstantiated. Do not add additional
/// non-trivial state here, add it to the ParserState class.
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
index 4d3ff2774f8e4..ec64a43c7bb8b 100644
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -43,9 +43,12 @@ struct SymbolState {
/// such as the current lexer position etc.
struct ParserState {
ParserState(const llvm::SourceMgr &sourceMgr, const ParserConfig &config,
- SymbolState &symbols, AsmParserState *asmState)
- : config(config), lex(sourceMgr, config.getContext()),
- curToken(lex.lexToken()), symbols(symbols), asmState(asmState) {}
+ SymbolState &symbols, AsmParserState *asmState,
+ AsmParserCodeCompleteContext *codeCompleteContext)
+ : config(config),
+ lex(sourceMgr, config.getContext(), codeCompleteContext),
+ curToken(lex.lexToken()), symbols(symbols), asmState(asmState),
+ codeCompleteContext(codeCompleteContext) {}
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;
@@ -65,6 +68,9 @@ struct ParserState {
/// populated during parsing.
AsmParserState *asmState;
+ /// An optional code completion context.
+ AsmParserCodeCompleteContext *codeCompleteContext;
+
// Contains the stack of default dialect to use when parsing regions.
// A new dialect get pushed to the stack before parsing regions nested
// under an operation implementing `OpAsmOpInterface`, and
diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
index 23410b37088d4..7a8d2975d2f47 100644
--- a/mlir/lib/Parser/Token.cpp
+++ b/mlir/lib/Parser/Token.cpp
@@ -78,12 +78,15 @@ Optional<bool> Token::getIntTypeSignedness() const {
/// removing the quote characters and unescaping the contents of the string. The
/// lexer has already verified that this token is valid.
std::string Token::getStringValue() const {
- assert(getKind() == string ||
+ assert(getKind() == string || getKind() == code_complete ||
(getKind() == at_identifier && getSpelling()[1] == '"'));
// Start by dropping the quotes.
- StringRef bytes = getSpelling().drop_front().drop_back();
- if (getKind() == at_identifier)
- bytes = bytes.drop_front();
+ StringRef bytes = getSpelling().drop_front();
+ if (getKind() != Token::code_complete) {
+ bytes = bytes.drop_back();
+ if (getKind() == at_identifier)
+ bytes = bytes.drop_front();
+ }
std::string result;
result.reserve(bytes.size());
@@ -190,3 +193,22 @@ bool Token::isKeyword() const {
#include "TokenKinds.def"
}
}
+
+/// Return true if this is a code completion token whose partial spelling
+/// starts with the sigil that introduces a token of kind `kind`. Only kinds
+/// with a distinguishing leading character are recognized.
+bool Token::isCodeCompletionFor(Kind kind) const {
+  if (!isCodeCompletion() || spelling.empty())
+    return false;
+
+  // Determine the leading character expected for the requested kind.
+  char sigil;
+  switch (kind) {
+  case Kind::string:
+    sigil = '"';
+    break;
+  case Kind::hash_identifier:
+    sigil = '#';
+    break;
+  case Kind::percent_identifier:
+    sigil = '%';
+    break;
+  case Kind::caret_identifier:
+    sigil = '^';
+    break;
+  case Kind::exclamation_identifier:
+    sigil = '!';
+    break;
+  default:
+    return false;
+  }
+  return spelling.front() == sigil;
+}
diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
index be0924cb9d67a..0e48805bb06f0 100644
--- a/mlir/lib/Parser/Token.h
+++ b/mlir/lib/Parser/Token.h
@@ -58,6 +58,19 @@ class Token {
/// Return true if this is one of the keyword token kinds (e.g. kw_if).
bool isKeyword() const;
+  /// Return true if this token marks the code completion point instead of a
+  /// regular piece of source.
+  bool isCodeCompletion() const { return getKind() == code_complete; }
+
+ /// Returns true if the current token represents a code completion for the
+ /// "normal" token type.
+ bool isCodeCompletionFor(Kind kind) const;
+
+  /// Return true if this token has kind `kind`, or is a code completion token
+  /// whose spelling begins like a token of that kind.
+  bool isOrIsCodeCompletionFor(Kind kind) const {
+    if (is(kind))
+      return true;
+    return isCodeCompletionFor(kind);
+  }
+
// Helpers to decode specific sorts of tokens.
/// For an integer token, return its value as an unsigned. If it doesn't fit,
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index f7146674234f6..207af3871f8d3 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -36,6 +36,7 @@
// Markers
TOK_MARKER(eof)
TOK_MARKER(error)
+TOK_MARKER(code_complete)
// Identifiers.
TOK_IDENTIFIER(bare_identifier) // foo
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index fdbbae2edfa63..03fe3ca6e41e7 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -780,6 +780,11 @@ enum class InsertTextFormat {
};
struct CompletionItem {
+  CompletionItem() = default;
+  /// Create a plain-text completion item labelled `labelText` (which by
+  /// default is also the inserted text) and categorized as `itemKind`.
+  CompletionItem(StringRef labelText, CompletionItemKind itemKind)
+      : label(labelText.str()), kind(itemKind),
+        insertTextFormat(InsertTextFormat::PlainText) {}
+
/// The label of this completion item. By default also the text that is
/// inserted when selecting this completion.
std::string label;
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index e5c24443a9efb..a3b86af4249d2 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -62,6 +62,12 @@ struct LSPServer::Impl {
void onDocumentSymbol(const DocumentSymbolParams &params,
Callback<std::vector<DocumentSymbol>> reply);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+
+  /// Handle a `textDocument/completion` request by replying with the
+  /// completion results for the requested document position.
+  void onCompletion(const CompletionParams &params,
+                    Callback<CompletionList> reply);
+
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
@@ -91,6 +97,15 @@ void LSPServer::Impl::onInitialize(const InitializeParams &params,
{"change", (int)TextDocumentSyncKind::Full},
{"save", true},
}},
+ {"completionProvider",
+ llvm::json::Object{
+ {"allCommitCharacters",
+ {"\t", "(", ")", "[", "]", "<", ">", ";", ",", "+", "-", "/", "*",
+ "&", "?", ".", "=", "|"}},
+ {"resolveProvider", false},
+ {"triggerCharacters",
+ {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
+ }},
{"definitionProvider", true},
{"referencesProvider", true},
{"hoverProvider", true},
@@ -192,6 +207,14 @@ void LSPServer::Impl::onDocumentSymbol(
reply(std::move(symbols));
}
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+/// Compute code completion results for the document and position named in
+/// `params`, and send them back through `reply`.
+void LSPServer::Impl::onCompletion(const CompletionParams &params,
+                                   Callback<CompletionList> reply) {
+  reply(server.getCodeCompletion(params.textDocument.uri, params.position));
+}
+
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//
@@ -229,6 +252,10 @@ LogicalResult LSPServer::run() {
messageHandler.method("textDocument/documentSymbol", impl.get(),
&Impl::onDocumentSymbol);
+ // Code Completion
+ messageHandler.method("textDocument/completion", impl.get(),
+ &Impl::onCompletion);
+
// Diagnostics
impl->publishDiagnostics =
messageHandler.outgoingNotification<PublishDiagnosticsParams>(
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index f5671692e6a82..c2097747d8452 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Parser/AsmParserState.h"
+#include "mlir/Parser/CodeComplete.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/SourceMgr.h"
@@ -276,6 +277,14 @@ struct MLIRDocument {
void findDocumentSymbols(Operation *op,
std::vector<lsp::DocumentSymbol> &symbols);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+  /// Compute the code completion results for position `completePos` in this
+  /// document. NOTE(review): `registry` presumably supplies the dialects made
+  /// available to the completion parse; confirm against the definition.
+  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+                                        const lsp::Position &completePos,
+                                        const DialectRegistry &registry);
+
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
@@ -615,6 +624,102 @@ void MLIRDocument::findDocumentSymbols(
findDocumentSymbols(&childOp, *childSymbols);
}
+//===----------------------------------------------------------------------===//
+// MLIRDocument: Code Completion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A parser code completion context for the MLIR LSP server. The parser calls
+/// back into it when it reaches the completion location. Each callback turns
+/// the requested completions into `lsp::CompletionItem`s and appends them to a
+/// caller-owned completion list.
+class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
+public:
+  /// Create a context that completes at `completeLoc` and appends results to
+  /// `completionList`. `ctx` is used to enumerate dialects and operations.
+  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
+                         MLIRContext *ctx)
+      : AsmParserCodeCompleteContext(completeLoc),
+        completionList(completionList), ctx(ctx) {}
+
+  /// Signal code completion for a dialect name.
+  void completeDialectName() final {
+    for (StringRef dialect : ctx->getAvailableDialects()) {
+      lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
+      item.sortText = "2";
+      item.detail = "dialect";
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Signal code completion for an operation name within the given dialect.
+  /// Only operations registered to that dialect are offered.
+  void completeOperationName(StringRef dialectName) final {
+    Dialect *dialect = ctx->getOrLoadDialect(dialectName);
+    if (!dialect)
+      return;
+
+    for (const auto &op : ctx->getRegisteredOperations()) {
+      if (&op.getDialect() != dialect)
+        continue;
+
+      // Drop the `<dialect>.` namespace prefix from the full operation name.
+      lsp::CompletionItem item(
+          op.getStringRef().drop_front(dialectName.size() + 1),
+          lsp::CompletionItemKind::Field);
+      item.sortText = "1";
+      item.detail = "operation";
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Append the given SSA value as a code completion result for SSA value
+  /// completions.
+  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
+    // Check if we need to insert the `%` or not.
+    bool stripPrefix = isPrecededBy('%');
+
+    lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
+    if (stripPrefix)
+      item.insertText = name.drop_front(1).str();
+    item.detail = std::move(typeData);
+    completionList.items.emplace_back(std::move(item));
+  }
+
+  /// Append the given block as a code completion result for block name
+  /// completions.
+  void appendBlockCompletion(StringRef name) final {
+    // Check if we need to insert the `^` or not.
+    bool stripPrefix = isPrecededBy('^');
+
+    lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
+    if (stripPrefix)
+      item.insertText = name.drop_front(1).str();
+    completionList.items.emplace_back(std::move(item));
+  }
+
+private:
+  /// Return true if the character immediately before the completion location
+  /// is `sigil`. In that case the user already typed it, and the inserted text
+  /// should omit it.
+  /// NOTE(review): this reads one byte before the completion location, so it
+  /// assumes that location is never the very first byte of the buffer.
+  bool isPrecededBy(char sigil) {
+    return getCodeCompleteLoc().getPointer()[-1] == sigil;
+  }
+
+  /// The list that completion results are appended to (owned by the caller).
+  lsp::CompletionList &completionList;
+  /// The context used to look up dialects and registered operations.
+  MLIRContext *ctx;
+};
+} // namespace
+
+/// Compute code completion results at `completePos` by reparsing this
+/// document's source with an LSP code completion context attached.
+/// Returns an empty list if the position does not map to a valid location.
+lsp::CompletionList
+MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
+                                const lsp::Position &completePos,
+                                const DialectRegistry &registry) {
+  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
+  if (!posLoc.isValid())
+    return lsp::CompletionList();
+
+  // To perform code completion, we run another parse of the module with the
+  // code completion context provided. The reparse uses a fresh context, so it
+  // does not touch this document's existing IR. That context is single
+  // threaded and accepts unregistered dialects.
+  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
+  tmpContext.allowUnregisteredDialects();
+  lsp::CompletionList completionList;
+  LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
+                                            &tmpContext);
+
+  // The parsed IR and parser state are thrown away; only the results gathered
+  // by `lspCompleteContext` are returned. The parse result is deliberately
+  // ignored, presumably because the source being edited is often incomplete.
+  Block tmpIR;
+  AsmParserState tmpState;
+  (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext,
+                        /*sourceFileLoc=*/nullptr, &tmpState,
+                        &lspCompleteContext);
+  return completionList;
+}
+
//===----------------------------------------------------------------------===//
// MLIRTextFileChunk
//===----------------------------------------------------------------------===//
@@ -670,6 +775,8 @@ class MLIRTextFile {
Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
lsp::Position hoverPos);
void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+ lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+ lsp::Position completePos);
private:
/// Find the MLIR document that contains the given position, and update the
@@ -813,6 +920,22 @@ void MLIRTextFile::findDocumentSymbols(
}
}
+/// Compute code completion results for a position in this (possibly chunked)
+/// text file. The request is delegated to the chunk containing the position,
+/// and any edit ranges in the results are mapped back to file coordinates.
+lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
+                                                    lsp::Position completePos) {
+  // Find the owning chunk; this also rebases `completePos` onto that chunk.
+  MLIRTextFileChunk &targetChunk = getChunkFor(completePos);
+  const auto &dialects = context.getDialectRegistry();
+  lsp::CompletionList result =
+      targetChunk.document.getCodeCompletion(uri, completePos, dialects);
+
+  // Shift every chunk-relative edit range back into whole-file coordinates.
+  for (lsp::CompletionItem &completion : result.items) {
+    for (lsp::TextEdit &extraEdit : completion.additionalTextEdits)
+      targetChunk.adjustLocForChunkOffset(extraEdit.range);
+    if (completion.textEdit)
+      targetChunk.adjustLocForChunkOffset(completion.textEdit->range);
+  }
+  return result;
+}
+
MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
if (chunks.size() == 1)
return *chunks.front();
@@ -898,3 +1021,12 @@ void lsp::MLIRServer::findDocumentSymbols(
if (fileIt != impl->files.end())
fileIt->second->findDocumentSymbols(symbols);
}
+
+/// Return the code completion list for `completePos` in the file named by
+/// `uri`, or an empty list if the server is not tracking that file.
+lsp::CompletionList
+lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt == impl->files.end())
+    return CompletionList();
+  return fileIt->second->getCodeCompletion(uri, completePos);
+}
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
index d78d8c3dc6ec8..85ccf0a68f1b0 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -16,6 +16,7 @@ namespace mlir {
class DialectRegistry;
namespace lsp {
+struct CompletionList;
struct Diagnostic;
struct DocumentSymbol;
struct Hover;
@@ -60,6 +61,10 @@ class MLIRServer {
void findDocumentSymbols(const URIForFile &uri,
std::vector<DocumentSymbol> &symbols);
+ /// Get the code completion list for the position within the given file.
+ CompletionList getCodeCompletion(const URIForFile &uri,
+ const Position &completePos);
+
private:
struct Impl;
diff --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test
new file mode 100644
index 0000000000000..b13dbe147f057
--- /dev/null
+++ b/mlir/test/mlir-lsp-server/completion.test
@@ -0,0 +1,105 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+ "uri":"test:///foo.mlir",
+ "languageId":"mlir",
+ "version":1,
+ "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":0,"character":0}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK: {
+// CHECK: "detail": "dialect",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 9,
+// CHECK: "label": "builtin",
+// CHECK: "sortText": "2"
+// CHECK: },
+// CHECK: {
+// CHECK: "detail": "operation",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "module",
+// CHECK: "sortText": "1"
+// CHECK: }
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":1,"character":9}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK: {
+// CHECK: "detail": "dialect",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 9,
+// CHECK: "label": "builtin",
+// CHECK: "sortText": "2"
+// CHECK: },
+// CHECK-NOT: "detail": "operation",
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":1,"character":17}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NOT: "detail": "dialect",
+// CHECK: {
+// CHECK: "detail": "operation",
+// CHECK: "insertTextFormat": 1,
+// CHECK: "kind": 5,
+// CHECK: "label": "module",
+// CHECK: "sortText": "1"
+// CHECK: }
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":2,"character":8}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: i32",
+// CHECK-NEXT: "insertText": "cast",
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 6,
+// CHECK-NEXT: "label": "%cast"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "arg #0: i32",
+// CHECK-NEXT: "insertText": "arg",
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 6,
+// CHECK-NEXT: "label": "%arg"
+// CHECK-NEXT: }
+// CHECK: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test
index db41a61a8a1a8..2fe528f8ed2f2 100644
--- a/mlir/test/mlir-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-lsp-server/initialize-params.test
@@ -5,6 +5,13 @@
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "capabilities": {
+// CHECK-NEXT: "completionProvider": {
+// CHECK-NEXT: "allCommitCharacters": [
+// CHECK: ],
+// CHECK-NEXT: "resolveProvider": false,
+// CHECK-NEXT: "triggerCharacters": [
+// CHECK: ]
+// CHECK-NEXT: },
// CHECK-NEXT: "definitionProvider": true,
// CHECK-NEXT: "documentSymbolProvider": false,
// CHECK-NEXT: "hoverProvider": true,
More information about the Mlir-commits
mailing list