[Mlir-commits] [mlir] ed2fb17 - [mlir:LSP] Add support for MLIR code completions

River Riddle llvmlistbot at llvm.org
Thu Jul 7 13:36:30 PDT 2022


Author: River Riddle
Date: 2022-07-07T13:35:54-07:00
New Revision: ed2fb1736ac1515838d136b57e5ad9a5c805001b

URL: https://github.com/llvm/llvm-project/commit/ed2fb1736ac1515838d136b57e5ad9a5c805001b
DIFF: https://github.com/llvm/llvm-project/commit/ed2fb1736ac1515838d136b57e5ad9a5c805001b.diff

LOG: [mlir:LSP] Add support for MLIR code completions

This commit adds code completion results to the MLIR LSP using
a new code completion context in the MLIR parser. This commit
adds initial completion for dialect, operation, SSA value, and
block names.

Differential Revision: https://reviews.llvm.org/D129183

Added: 
    mlir/include/mlir/Parser/CodeComplete.h
    mlir/test/mlir-lsp-server/completion.test

Modified: 
    mlir/include/mlir/Parser/Parser.h
    mlir/lib/Parser/AffineParser.cpp
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/Parser/Lexer.cpp
    mlir/lib/Parser/Lexer.h
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/lib/Parser/ParserState.h
    mlir/lib/Parser/Token.cpp
    mlir/lib/Parser/Token.h
    mlir/lib/Parser/TokenKinds.def
    mlir/lib/Tools/lsp-server-support/Protocol.h
    mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
    mlir/test/mlir-lsp-server/initialize-params.test

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h
new file mode 100644
index 0000000000000..24d0181456f67
--- /dev/null
+++ b/mlir/include/mlir/Parser/CodeComplete.h
@@ -0,0 +1,58 @@
+//===- CodeComplete.h - MLIR Asm CodeComplete Context -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PARSER_CODECOMPLETE_H
+#define MLIR_PARSER_CODECOMPLETE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+/// This class provides an abstract interface into the parser for hooking in
+/// code completion events. This class is only really useful for providing
+/// language tooling for MLIR; general clients should not need to use this
+/// class.
+class AsmParserCodeCompleteContext {
+public:
+  virtual ~AsmParserCodeCompleteContext();
+
+  /// Return the source location used to provide code completion.
+  SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; }
+
+  //===--------------------------------------------------------------------===//
+  // Completion Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Signal code completion for a dialect name.
+  virtual void completeDialectName() = 0;
+
+  /// Signal code completion for an operation name within the given dialect.
+  virtual void completeOperationName(StringRef dialectName) = 0;
+
+  /// Append the given SSA value as a code completion result for SSA value
+  /// completions.
+  virtual void appendSSAValueCompletion(StringRef name,
+                                        std::string typeData) = 0;
+
+  /// Append the given block as a code completion result for block name
+  /// completions.
+  virtual void appendBlockCompletion(StringRef name) = 0;
+
+protected:
+  /// Create a new code completion context with the given code complete
+  /// location.
+  explicit AsmParserCodeCompleteContext(SMLoc codeCompleteLoc)
+      : codeCompleteLoc(codeCompleteLoc) {}
+
+private:
+  /// The location used to code complete.
+  SMLoc codeCompleteLoc;
+};
+} // namespace mlir
+
+#endif // MLIR_PARSER_CODECOMPLETE_H

diff  --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index 69f02c45d6fbb..5c9543f187d6a 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -26,6 +26,7 @@ class StringRef;
 
 namespace mlir {
 class AsmParserState;
+class AsmParserCodeCompleteContext;
 
 namespace detail {
 
@@ -83,11 +84,13 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
 /// source file that is being parsed. If `asmState` is non-null, it is populated
 /// with detailed information about the parsed IR (including exact locations for
 /// SSA uses and definitions). `asmState` should only be provided if this
-/// detailed information is desired.
-LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
-                              const ParserConfig &config,
-                              LocationAttr *sourceFileLoc = nullptr,
-                              AsmParserState *asmState = nullptr);
+/// detailed information is desired. If `codeCompleteContext` is non-null, it is
+/// used to signal tracking of a code completion event (generally only ever
+/// useful for LSP or other high level language tooling).
+LogicalResult parseSourceFile(
+    const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config,
+    LocationAttr *sourceFileLoc = nullptr, AsmParserState *asmState = nullptr,
+    AsmParserCodeCompleteContext *codeCompleteContext = nullptr);
 
 /// This parses the file specified by the indicated filename and appends parsed
 /// operations to the given block. If the block is non-empty, the operations are

diff  --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
index 7fa054bf3913a..345b3a505b069 100644
--- a/mlir/lib/Parser/AffineParser.cpp
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -743,7 +743,8 @@ IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context,
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState symbolState;
   ParserConfig config(context);
-  ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr);
+  ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr,
+                    /*codeCompleteContext=*/nullptr);
   Parser parser(state);
 
   raw_ostream &os = printDiagnosticInfo ? llvm::errs() : llvm::nulls();

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 623ea8ef12ea1..4478d1936e661 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -277,7 +277,8 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
   ParserConfig config(context);
-  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr);
+  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
+                    /*codeCompleteContext=*/nullptr);
   Parser parser(state);
 
   SourceMgrDiagnosticHandler handler(

diff  --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp
index f3e4ce6b2768e..cbebaf24159d4 100644
--- a/mlir/lib/Parser/Lexer.cpp
+++ b/mlir/lib/Parser/Lexer.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser/CodeComplete.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/SourceMgr.h"
@@ -26,11 +27,16 @@ static bool isPunct(char c) {
   return c == '$' || c == '.' || c == '_' || c == '-';
 }
 
-Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context)
-    : sourceMgr(sourceMgr), context(context) {
+Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
+             AsmParserCodeCompleteContext *codeCompleteContext)
+    : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
   auto bufferID = sourceMgr.getMainFileID();
   curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
   curPtr = curBuffer.begin();
+
+  // Set the code completion location if it was provided.
+  if (codeCompleteContext)
+    codeCompleteLoc = codeCompleteContext->getCodeCompleteLoc().getPointer();
 }
 
 /// Encode the specified source location information into an attribute for
@@ -61,6 +67,12 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
 Token Lexer::lexToken() {
   while (true) {
     const char *tokStart = curPtr;
+
+    // Check to see if the current token is at the code completion location.
+    if (tokStart == codeCompleteLoc)
+      return formToken(Token::code_complete, tokStart);
+
+    // Lex the next token.
     switch (*curPtr++) {
     default:
       // Handle bare identifiers.
@@ -357,17 +369,25 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
   // Parse suffix-id.
   if (isdigit(*curPtr)) {
     // If suffix-id starts with a digit, the rest must be digits.
-    while (isdigit(*curPtr)) {
+    while (isdigit(*curPtr))
       ++curPtr;
-    }
   } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
     do {
       ++curPtr;
     } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
+  } else if (curPtr == codeCompleteLoc) {
+    return formToken(Token::code_complete, tokStart);
   } else {
     return emitError(curPtr - 1, errorKind);
   }
 
+  // Check for a code completion within the identifier.
+  if (codeCompleteLoc && codeCompleteLoc >= tokStart &&
+      codeCompleteLoc <= curPtr) {
+    return Token(Token::code_complete,
+                 StringRef(tokStart, codeCompleteLoc - tokStart));
+  }
+
   return formToken(kind, tokStart);
 }
 
@@ -380,6 +400,13 @@ Token Lexer::lexString(const char *tokStart) {
   assert(curPtr[-1] == '"');
 
   while (true) {
+    // Check to see if there is a code completion location within the string. In
+    // these cases we generate a completion location and place the currently
+    // lexed string within the token. This allows the parser to use the
+    // partially lexed string when computing the completion results.
+    if (curPtr == codeCompleteLoc)
+      return formToken(Token::code_complete, tokStart);
+
     switch (*curPtr++) {
     case '"':
       return formToken(Token::string, tokStart);

diff  --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h
index 58b5e2321c4d3..e09ae168e3bba 100644
--- a/mlir/lib/Parser/Lexer.h
+++ b/mlir/lib/Parser/Lexer.h
@@ -22,7 +22,8 @@ class Location;
 /// This class breaks up the current file into a token stream.
 class Lexer {
 public:
-  explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+  explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
+                 AsmParserCodeCompleteContext *codeCompleteContext);
 
   const llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
 
@@ -64,6 +65,10 @@ class Lexer {
   StringRef curBuffer;
   const char *curPtr;
 
+  /// An optional code completion point within the input file, used to indicate
+  /// the position of a code completion token.
+  const char *codeCompleteLoc;
+
   Lexer(const Lexer &) = delete;
   void operator=(const Lexer &) = delete;
 };

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 5bc1e43c60a0e..51a65a99dfa7a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/AsmParserState.h"
+#include "mlir/Parser/CodeComplete.h"
 #include "mlir/Parser/Parser.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -32,6 +33,12 @@ using namespace mlir::detail;
 using llvm::MemoryBuffer;
 using llvm::SourceMgr;
 
+//===----------------------------------------------------------------------===//
+// CodeComplete
+//===----------------------------------------------------------------------===//
+
+AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
@@ -334,6 +341,60 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
   return entry.second;
 }
 
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+ParseResult Parser::codeCompleteDialectName() {
+  state.codeCompleteContext->completeDialectName();
+  return failure();
+}
+
+ParseResult Parser::codeCompleteOperationName(StringRef dialectName) {
+  // Perform some simple validation on the dialect name. This doesn't need to be
+  // extensive, it's more of an optimization (to avoid checking completion
+  // results when we know they will fail).
+  if (dialectName.empty() || dialectName.contains('.'))
+    return failure();
+  state.codeCompleteContext->completeOperationName(dialectName);
+  return failure();
+}
+
+ParseResult Parser::codeCompleteDialectOrElidedOpName(SMLoc loc) {
+  // Check to see if there is anything else on the current line. This check
+  // isn't strictly necessary, but it does avoid unnecessarily triggering
+  // completions for operations and dialects in situations where we don't want
+  // them (e.g. at the end of an operation).
+  auto shouldIgnoreOpCompletion = [&]() {
+    const char *bufBegin = state.lex.getBufferBegin();
+    const char *it = loc.getPointer() - 1;
+    for (; it > bufBegin && *it != '\n'; --it)
+      if (!llvm::is_contained(StringRef(" \t\r"), *it))
+        return true;
+    return false;
+  };
+  if (shouldIgnoreOpCompletion())
+    return failure();
+
+  // The completion here is either for a dialect name, or an operation name
+  // whose dialect prefix was elided. For this we simply invoke both of the
+  // individual completion methods.
+  (void)codeCompleteDialectName();
+  return codeCompleteOperationName(state.defaultDialectStack.back());
+}
+
+ParseResult Parser::codeCompleteStringDialectOrOperationName(StringRef name) {
+  // If the name is empty, this is the start of the string and contains the
+  // dialect.
+  if (name.empty())
+    return codeCompleteDialectName();
+
+  // Otherwise, we treat this as completing an operation name. The current name
+  // is used as the dialect namespace.
+  if (name.consume_back("."))
+    return codeCompleteOperationName(name);
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // OperationParser
 //===----------------------------------------------------------------------===//
@@ -497,6 +558,17 @@ class OperationParser : public Parser {
   /// us to diagnose references to blocks that are not defined precisely.
   Block *getBlockNamed(StringRef name, SMLoc loc);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  /// The set of various code completion methods. Every completion method
+  /// returns `failure` to stop the parsing process after providing completion
+  /// results.
+
+  ParseResult codeCompleteSSAUse();
+  ParseResult codeCompleteBlock();
+
 private:
   /// This class represents a definition of a Block.
   struct BlockDefinition {
@@ -790,7 +862,7 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
 ///
 ParseResult OperationParser::parseOptionalSSAUseList(
     SmallVectorImpl<UnresolvedOperand> &results) {
-  if (getToken().isNot(Token::percent_identifier))
+  if (!getToken().isOrIsCodeCompletionFor(Token::percent_identifier))
     return success();
   return parseCommaSeparatedList([&]() -> ParseResult {
     UnresolvedOperand result;
@@ -807,6 +879,9 @@ ParseResult OperationParser::parseOptionalSSAUseList(
 ///
 ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result,
                                          bool allowResultNumber) {
+  if (getToken().isCodeCompletion())
+    return codeCompleteSSAUse();
+
   result.name = getTokenSpelling();
   result.number = 0;
   result.location = getToken().getLoc();
@@ -1017,6 +1092,10 @@ ParseResult OperationParser::parseOperation() {
     op = parseCustomOperation(resultIDs);
   else if (nameTok.is(Token::string))
     op = parseGenericOperation();
+  else if (nameTok.isCodeCompletionFor(Token::string))
+    return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
+  else if (nameTok.isCodeCompletion())
+    return codeCompleteDialectOrElidedOpName(loc);
   else
     return emitWrongTokenError("expected operation name in quotes");
 
@@ -1071,6 +1150,9 @@ ParseResult OperationParser::parseOperation() {
 ///   successor ::= block-id
 ///
 ParseResult OperationParser::parseSuccessor(Block *&dest) {
+  if (getToken().isCodeCompletion())
+    return codeCompleteBlock();
+
   // Verify branch is identifier and get the matching block.
   if (!getToken().is(Token::caret_identifier))
     return emitWrongTokenError("expected block name");
@@ -1391,7 +1473,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   OptionalParseResult
   parseOptionalOperand(UnresolvedOperand &result,
                        bool allowResultNumber = true) override {
-    if (parser.getToken().is(Token::percent_identifier))
+    if (parser.getToken().isOrIsCodeCompletionFor(Token::percent_identifier))
       return parseOperand(result, allowResultNumber);
     return llvm::None;
   }
@@ -1406,14 +1488,15 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
     if (delimiter == Delimiter::None) {
       // parseCommaSeparatedList doesn't handle the missing case for "none",
       // so we handle it custom here.
-      if (parser.getToken().isNot(Token::percent_identifier)) {
+      Token tok = parser.getToken();
+      if (!tok.isOrIsCodeCompletionFor(Token::percent_identifier)) {
         // If we didn't require any operands or required exactly zero (weird)
         // then this is success.
         if (requiredOperandCount == -1 || requiredOperandCount == 0)
           return success();
 
         // Otherwise, try to produce a nice error message.
-        if (parser.getToken().isAny(Token::l_paren, Token::l_square))
+        if (tok.isAny(Token::l_paren, Token::l_square))
           return parser.emitError("unexpected delimiter");
         return parser.emitWrongTokenError("expected operand");
       }
@@ -1597,7 +1680,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 
   /// Parse an optional operation successor and its operand list.
   OptionalParseResult parseOptionalSuccessor(Block *&dest) override {
-    if (parser.getToken().isNot(Token::caret_identifier))
+    if (!parser.getToken().isOrIsCodeCompletionFor(Token::caret_identifier))
       return llvm::None;
     return parseSuccessor(dest);
   }
@@ -1681,7 +1764,8 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 } // namespace
 
 FailureOr<OperationName> OperationParser::parseCustomOperationName() {
-  std::string opName = getTokenSpelling().str();
+  Token nameTok = getToken();
+  StringRef opName = nameTok.getSpelling();
   if (opName.empty())
     return (emitError("empty operation name is invalid"), failure());
   consumeToken();
@@ -1694,11 +1778,17 @@ FailureOr<OperationName> OperationParser::parseCustomOperationName() {
 
   // If the operation doesn't have a dialect prefix try using the default
   // dialect.
-  auto opNameSplit = StringRef(opName).split('.');
+  auto opNameSplit = opName.split('.');
   StringRef dialectName = opNameSplit.first;
+  std::string opNameStorage;
   if (opNameSplit.second.empty()) {
+    // If the name didn't have a prefix, check for a code completion request.
+    if (getToken().isCodeCompletion() && opName.back() == '.')
+      return codeCompleteOperationName(dialectName);
+
     dialectName = getState().defaultDialectStack.back();
-    opName = (dialectName + "." + opName).str();
+    opNameStorage = (dialectName + "." + opName).str();
+    opName = opNameStorage;
   }
 
   // Try to load the dialect before returning the operation name to make sure
@@ -2091,6 +2181,58 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// Code Completion
+//===----------------------------------------------------------------------===//
+
+ParseResult OperationParser::codeCompleteSSAUse() {
+  std::string detailData;
+  llvm::raw_string_ostream detailOS(detailData);
+  for (IsolatedSSANameScope &scope : isolatedNameScopes) {
+    for (auto &it : scope.values) {
+      if (it.second.empty())
+        continue;
+      Value frontValue = it.second.front().value;
+
+      // If the value isn't a forward reference, we also add the name of the op
+      // to the detail.
+      if (auto result = frontValue.dyn_cast<OpResult>()) {
+        if (!forwardRefPlaceholders.count(result))
+          detailOS << result.getOwner()->getName() << ": ";
+      } else {
+        detailOS << "arg #" << frontValue.cast<BlockArgument>().getArgNumber()
+                 << ": ";
+      }
+
+      // Emit the type of the values to aid with completion selection.
+      detailOS << frontValue.getType();
+
+      // FIXME: We should define a policy for packed values, e.g. with a limit
+      // on the detail size, but it isn't clear what would be useful right now.
+      // For now we just only emit the first type.
+      if (it.second.size() > 1)
+        detailOS << ", ...";
+
+      state.codeCompleteContext->appendSSAValueCompletion(
+          it.getKey(), std::move(detailOS.str()));
+    }
+  }
+
+  return failure();
+}
+
+ParseResult OperationParser::codeCompleteBlock() {
+  // Only complete if nothing past an optional `^` was typed; this avoids
+  // weirdness when we encounter a `.` within the identifier.
+  StringRef spelling = getTokenSpelling();
+  if (!(spelling.empty() || spelling == "^"))
+    return failure();
+
+  for (const auto &it : blocksByName.back())
+    state.codeCompleteContext->appendBlockCompletion(it.getFirst());
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // Top-level entity parsing.
 //===----------------------------------------------------------------------===//
@@ -2418,10 +2560,11 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
 
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
-                                    Block *block, const ParserConfig &config,
-                                    LocationAttr *sourceFileLoc,
-                                    AsmParserState *asmState) {
+LogicalResult
+mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
+                      const ParserConfig &config, LocationAttr *sourceFileLoc,
+                      AsmParserState *asmState,
+                      AsmParserCodeCompleteContext *codeCompleteContext) {
   const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
 
   Location parserLoc =
@@ -2431,7 +2574,8 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
     *sourceFileLoc = parserLoc;
 
   SymbolState aliasState;
-  ParserState state(sourceMgr, config, aliasState, asmState);
+  ParserState state(sourceMgr, config, aliasState, asmState,
+                    codeCompleteContext);
   return TopLevelOperationParser(state).parse(block, parserLoc);
 }
 

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index e762fc6d3543b..f1c87ca4916b5 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -306,6 +306,20 @@ class Parser {
   parseAffineExprOfSSAIds(AffineExpr &expr,
                           function_ref<ParseResult(bool)> parseElement);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  /// The set of various code completion methods. Every completion method
+  /// returns `failure` to signal that parsing should abort after any desired
+  /// completions have been enqueued. Note that `failure` does not mean
+  /// completion failed, it's just a signal to the parser to stop.
+
+  ParseResult codeCompleteDialectName();
+  ParseResult codeCompleteOperationName(StringRef dialectName);
+  ParseResult codeCompleteDialectOrElidedOpName(SMLoc loc);
+  ParseResult codeCompleteStringDialectOrOperationName(StringRef name);
+
 protected:
   /// The Parser is subclassed and reinstantiated.  Do not add additional
   /// non-trivial state here, add it to the ParserState class.

diff  --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
index 4d3ff2774f8e4..ec64a43c7bb8b 100644
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -43,9 +43,12 @@ struct SymbolState {
 /// such as the current lexer position etc.
 struct ParserState {
   ParserState(const llvm::SourceMgr &sourceMgr, const ParserConfig &config,
-              SymbolState &symbols, AsmParserState *asmState)
-      : config(config), lex(sourceMgr, config.getContext()),
-        curToken(lex.lexToken()), symbols(symbols), asmState(asmState) {}
+              SymbolState &symbols, AsmParserState *asmState,
+              AsmParserCodeCompleteContext *codeCompleteContext)
+      : config(config),
+        lex(sourceMgr, config.getContext(), codeCompleteContext),
+        curToken(lex.lexToken()), symbols(symbols), asmState(asmState),
+        codeCompleteContext(codeCompleteContext) {}
   ParserState(const ParserState &) = delete;
   void operator=(const ParserState &) = delete;
 
@@ -65,6 +68,9 @@ struct ParserState {
   /// populated during parsing.
   AsmParserState *asmState;
 
+  /// An optional code completion context.
+  AsmParserCodeCompleteContext *codeCompleteContext;
+
   // Contains the stack of default dialect to use when parsing regions.
   // A new dialect get pushed to the stack before parsing regions nested
   // under an operation implementing `OpAsmOpInterface`, and

diff  --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
index 23410b37088d4..7a8d2975d2f47 100644
--- a/mlir/lib/Parser/Token.cpp
+++ b/mlir/lib/Parser/Token.cpp
@@ -78,12 +78,15 @@ Optional<bool> Token::getIntTypeSignedness() const {
 /// removing the quote characters and unescaping the contents of the string. The
 /// lexer has already verified that this token is valid.
 std::string Token::getStringValue() const {
-  assert(getKind() == string ||
+  assert(getKind() == string || getKind() == code_complete ||
          (getKind() == at_identifier && getSpelling()[1] == '"'));
   // Start by dropping the quotes.
-  StringRef bytes = getSpelling().drop_front().drop_back();
-  if (getKind() == at_identifier)
-    bytes = bytes.drop_front();
+  StringRef bytes = getSpelling().drop_front();
+  if (getKind() != Token::code_complete) {
+    bytes = bytes.drop_back();
+    if (getKind() == at_identifier)
+      bytes = bytes.drop_front();
+  }
 
   std::string result;
   result.reserve(bytes.size());
@@ -190,3 +193,22 @@ bool Token::isKeyword() const {
 #include "TokenKinds.def"
   }
 }
+
+bool Token::isCodeCompletionFor(Kind kind) const {
+  if (!isCodeCompletion() || spelling.empty())
+    return false;
+  switch (kind) {
+  case Kind::string:
+    return spelling[0] == '"';
+  case Kind::hash_identifier:
+    return spelling[0] == '#';
+  case Kind::percent_identifier:
+    return spelling[0] == '%';
+  case Kind::caret_identifier:
+    return spelling[0] == '^';
+  case Kind::exclamation_identifier:
+    return spelling[0] == '!';
+  default:
+    return false;
+  }
+}

diff  --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
index be0924cb9d67a..0e48805bb06f0 100644
--- a/mlir/lib/Parser/Token.h
+++ b/mlir/lib/Parser/Token.h
@@ -58,6 +58,19 @@ class Token {
   /// Return true if this is one of the keyword token kinds (e.g. kw_if).
   bool isKeyword() const;
 
+  /// Returns true if the current token represents a code completion.
+  bool isCodeCompletion() const { return is(code_complete); }
+
+  /// Returns true if the current token represents a code completion for the
+  /// "normal" token type.
+  bool isCodeCompletionFor(Kind kind) const;
+
+  /// Returns true if the current token is the given type, or represents a code
+  /// completion for that type.
+  bool isOrIsCodeCompletionFor(Kind kind) const {
+    return is(kind) || isCodeCompletionFor(kind);
+  }
+
   // Helpers to decode specific sorts of tokens.
 
   /// For an integer token, return its value as an unsigned.  If it doesn't fit,

diff  --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index f7146674234f6..207af3871f8d3 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -36,6 +36,7 @@
 // Markers
 TOK_MARKER(eof)
 TOK_MARKER(error)
+TOK_MARKER(code_complete)
 
 // Identifiers.
 TOK_IDENTIFIER(bare_identifier)        // foo

diff  --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index fdbbae2edfa63..03fe3ca6e41e7 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -780,6 +780,11 @@ enum class InsertTextFormat {
 };
 
 struct CompletionItem {
+  CompletionItem() = default;
+  CompletionItem(StringRef label, CompletionItemKind kind)
+      : label(label.str()), kind(kind),
+        insertTextFormat(InsertTextFormat::PlainText) {}
+
   /// The label of this completion item. By default also the text that is
   /// inserted when selecting this completion.
   std::string label;

diff  --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index e5c24443a9efb..a3b86af4249d2 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -62,6 +62,12 @@ struct LSPServer::Impl {
   void onDocumentSymbol(const DocumentSymbolParams &params,
                         Callback<std::vector<DocumentSymbol>> reply);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+
+  void onCompletion(const CompletionParams &params,
+                    Callback<CompletionList> reply);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -91,6 +97,15 @@ void LSPServer::Impl::onInitialize(const InitializeParams &params,
            {"change", (int)TextDocumentSyncKind::Full},
            {"save", true},
        }},
+      {"completionProvider",
+       llvm::json::Object{
+           {"allCommitCharacters",
+            {"\t", "(", ")", "[", "]", "<", ">", ";", ",", "+", "-", "/", "*",
+             "&", "?", ".", "=", "|"}},
+           {"resolveProvider", false},
+           {"triggerCharacters",
+            {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
+       }},
       {"definitionProvider", true},
       {"referencesProvider", true},
       {"hoverProvider", true},
@@ -192,6 +207,14 @@ void LSPServer::Impl::onDocumentSymbol(
   reply(std::move(symbols));
 }
 
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+void LSPServer::Impl::onCompletion(const CompletionParams &params,
+                                   Callback<CompletionList> reply) {
+  reply(server.getCodeCompletion(params.textDocument.uri, params.position));
+}
+
 //===----------------------------------------------------------------------===//
 // LSPServer
 //===----------------------------------------------------------------------===//
@@ -229,6 +252,10 @@ LogicalResult LSPServer::run() {
   messageHandler.method("textDocument/documentSymbol", impl.get(),
                         &Impl::onDocumentSymbol);
 
+  // Code Completion
+  messageHandler.method("textDocument/completion", impl.get(),
+                        &Impl::onCompletion);
+
   // Diagnostics
   impl->publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index f5671692e6a82..c2097747d8452 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Parser/AsmParserState.h"
+#include "mlir/Parser/CodeComplete.h"
 #include "mlir/Parser/Parser.h"
 #include "llvm/Support/SourceMgr.h"
 
@@ -276,6 +277,14 @@ struct MLIRDocument {
   void findDocumentSymbols(Operation *op,
                            std::vector<lsp::DocumentSymbol> &symbols);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+                                        const lsp::Position &completePos,
+                                        const DialectRegistry &registry);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -615,6 +624,118 @@ void MLIRDocument::findDocumentSymbols(
       findDocumentSymbols(&childOp, *childSymbols);
 }
 
+//===----------------------------------------------------------------------===//
+// MLIRDocument: Code Completion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Bridges the parser's code completion callbacks to LSP: each callback
+/// appends `lsp::CompletionItem`s to a caller-owned completion list.
+class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
+public:
+  /// `bufferStart` must be the first character of the source buffer that
+  /// contains `completeLoc`; it bounds the one-character look-behind used to
+  /// detect an already typed `%` or `^` prefix.
+  LSPCodeCompleteContext(SMLoc completeLoc, const char *bufferStart,
+                         lsp::CompletionList &completionList,
+                         MLIRContext *ctx)
+      : AsmParserCodeCompleteContext(completeLoc), bufferStart(bufferStart),
+        completionList(completionList), ctx(ctx) {}
+
+  /// Signal code completion for a dialect name.
+  void completeDialectName() final {
+    for (StringRef dialect : ctx->getAvailableDialects()) {
+      lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
+      item.sortText = "2";
+      item.detail = "dialect";
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Signal code completion for an operation name within the given dialect.
+  void completeOperationName(StringRef dialectName) final {
+    Dialect *dialect = ctx->getOrLoadDialect(dialectName);
+    if (!dialect)
+      return;
+
+    for (const auto &op : ctx->getRegisteredOperations()) {
+      if (&op.getDialect() != dialect)
+        continue;
+
+      // Offer only the part of the name after the `<dialect>.` prefix.
+      lsp::CompletionItem item(
+          op.getStringRef().drop_front(dialectName.size() + 1),
+          lsp::CompletionItemKind::Field);
+      item.sortText = "1";
+      item.detail = "operation";
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Append the given SSA value as a code completion result for SSA value
+  /// completions.
+  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
+    lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
+    // If the `%` has already been typed, insert the name without it.
+    if (isPrecededBy('%'))
+      item.insertText = name.drop_front(1).str();
+    item.detail = std::move(typeData);
+    completionList.items.emplace_back(std::move(item));
+  }
+
+  /// Append the given block as a code completion result for block name
+  /// completions.
+  void appendBlockCompletion(StringRef name) final {
+    lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
+    // If the `^` has already been typed, insert the name without it.
+    if (isPrecededBy('^'))
+      item.insertText = name.drop_front(1).str();
+    completionList.items.emplace_back(std::move(item));
+  }
+
+private:
+  /// Return true if the character just before the completion location is
+  /// `c`. A location at the very start of the buffer has no preceding
+  /// character, so it never matches and nothing before the buffer is read.
+  bool isPrecededBy(char c) {
+    const char *loc = getCodeCompleteLoc().getPointer();
+    return loc != bufferStart && loc[-1] == c;
+  }
+
+  const char *bufferStart;
+  lsp::CompletionList &completionList;
+  MLIRContext *ctx;
+};
+} // namespace
+
+lsp::CompletionList
+MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
+                                const lsp::Position &completePos,
+                                const DialectRegistry &registry) {
+  SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
+  if (!posLoc.isValid())
+    return lsp::CompletionList();
+
+  // To perform code completion, we run another parse of the module with the
+  // code completion context provided.
+  MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
+  tmpContext.allowUnregisteredDialects();
+  lsp::CompletionList completionList;
+  unsigned bufferID = sourceMgr.FindBufferContainingLoc(posLoc);
+  LSPCodeCompleteContext lspCompleteContext(
+      posLoc, sourceMgr.getMemoryBuffer(bufferID)->getBufferStart(),
+      completionList, &tmpContext);
+
+  // The parse result is deliberately discarded; only the callbacks made on
+  // `lspCompleteContext` during the parse are used.
+  Block tmpIR;
+  AsmParserState tmpState;
+  (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext,
+                        /*sourceFileLoc=*/nullptr, &tmpState,
+                        &lspCompleteContext);
+  return completionList;
+}
+
 //===----------------------------------------------------------------------===//
 // MLIRTextFileChunk
 //===----------------------------------------------------------------------===//
@@ -670,6 +775,8 @@ class MLIRTextFile {
   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                  lsp::Position hoverPos);
   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+                                        lsp::Position completePos);
 
 private:
   /// Find the MLIR document that contains the given position, and update the
@@ -813,6 +920,26 @@ void MLIRTextFile::findDocumentSymbols(
   }
 }
 
+/// Compute completions at `completePos`. The position is first resolved to
+/// the chunk that contains it (`getChunkFor` takes it by reference and may
+/// rewrite it), completion runs on that chunk's document, and the chunk
+/// offset is then applied to every edit range in the results.
+lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
+                                                    lsp::Position completePos) {
+  MLIRTextFileChunk &chunk = getChunkFor(completePos);
+  const DialectRegistry &registry = context.getDialectRegistry();
+  lsp::CompletionList result =
+      chunk.document.getCodeCompletion(uri, completePos, registry);
+
+  for (lsp::CompletionItem &item : result.items) {
+    for (lsp::TextEdit &edit : item.additionalTextEdits)
+      chunk.adjustLocForChunkOffset(edit.range);
+    if (item.textEdit)
+      chunk.adjustLocForChunkOffset(item.textEdit->range);
+  }
+  return result;
+}
+
 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -898,3 +1021,14 @@ void lsp::MLIRServer::findDocumentSymbols(
   if (fileIt != impl->files.end())
     fileIt->second->findDocumentSymbols(symbols);
 }
+
+/// Return the completions for `completePos` within the file named by `uri`,
+/// or an empty list if the server holds no file for that URI.
+lsp::CompletionList
+lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt == impl->files.end())
+    return CompletionList();
+  return fileIt->second->getCodeCompletion(uri, completePos);
+}

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
index d78d8c3dc6ec8..85ccf0a68f1b0 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -16,6 +16,7 @@ namespace mlir {
 class DialectRegistry;
 
 namespace lsp {
+struct CompletionList;
 struct Diagnostic;
 struct DocumentSymbol;
 struct Hover;
@@ -60,6 +61,10 @@ class MLIRServer {
   void findDocumentSymbols(const URIForFile &uri,
                            std::vector<DocumentSymbol> &symbols);
 
+  /// Get the code completion list for the position within the given file.
+  CompletionList getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos);
+
 private:
   struct Impl;
 

diff  --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test
new file mode 100644
index 0000000000000..b13dbe147f057
--- /dev/null
+++ b/mlir/test/mlir-lsp-server/completion.test
@@ -0,0 +1,105 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.mlir",
+  "languageId":"mlir",
+  "version":1,
+  "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":0,"character":0}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK:           {
+// CHECK:             "detail": "dialect",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 9,
+// CHECK:             "label": "builtin",
+// CHECK:             "sortText": "2"
+// CHECK:           },
+// CHECK:           {
+// CHECK:             "detail": "operation",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "module",
+// CHECK:             "sortText": "1"
+// CHECK:           }
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":1,"character":9}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK:           {
+// CHECK:             "detail": "dialect",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 9,
+// CHECK:             "label": "builtin",
+// CHECK:             "sortText": "2"
+// CHECK:           },
+// CHECK-NOT:       "detail": "operation",
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":1,"character":17}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NOT:       "detail": "dialect",
+// CHECK:           {
+// CHECK:             "detail": "operation",
+// CHECK:             "insertTextFormat": 1,
+// CHECK:             "kind": 5,
+// CHECK:             "label": "module",
+// CHECK:             "sortText": "1"
+// CHECK:           }
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":2,"character":8}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "builtin.unrealized_conversion_cast: i32",
+// CHECK-NEXT:        "insertText": "cast",
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 6,
+// CHECK-NEXT:        "label": "%cast"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "arg #0: i32",
+// CHECK-NEXT:        "insertText": "arg",
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 6,
+// CHECK-NEXT:        "label": "%arg"
+// CHECK-NEXT:      }
+// CHECK:         ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}

diff  --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test
index db41a61a8a1a8..2fe528f8ed2f2 100644
--- a/mlir/test/mlir-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-lsp-server/initialize-params.test
@@ -5,6 +5,13 @@
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "capabilities": {
+// CHECK-NEXT:      "completionProvider": {
+// CHECK-NEXT:        "allCommitCharacters": [
+// CHECK:             ],
+// CHECK-NEXT:        "resolveProvider": false,
+// CHECK-NEXT:        "triggerCharacters": [
+// CHECK:             ]
+// CHECK-NEXT:      },
 // CHECK-NEXT:      "definitionProvider": true,
 // CHECK-NEXT:      "documentSymbolProvider": false,
 // CHECK-NEXT:      "hoverProvider": true,


        


More information about the Mlir-commits mailing list