[Mlir-commits] [mlir] 008de48 - [mlir][PDLL] Add code completion to the PDLL language server
River Riddle
llvmlistbot at llvm.org
Sat Mar 19 13:29:34 PDT 2022
Author: River Riddle
Date: 2022-03-19T13:28:24-07:00
New Revision: 008de486f706ef25a66d4384c2c3af1ed86e680e
URL: https://github.com/llvm/llvm-project/commit/008de486f706ef25a66d4384c2c3af1ed86e680e
DIFF: https://github.com/llvm/llvm-project/commit/008de486f706ef25a66d4384c2c3af1ed86e680e.diff
LOG: [mlir][PDLL] Add code completion to the PDLL language server
This commit adds code completion support to the language server.
It initially provides completions for member access; attribute, constraint,
dialect, and operation names; and pattern metadata.
Differential Revision: https://reviews.llvm.org/D121544
Added:
mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
mlir/test/mlir-pdll-lsp-server/completion.test
Modified:
mlir/include/mlir/Tools/PDLL/ODS/Context.h
mlir/include/mlir/Tools/PDLL/Parser/Parser.h
mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
mlir/lib/Tools/PDLL/Parser/Lexer.cpp
mlir/lib/Tools/PDLL/Parser/Lexer.h
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/lib/Tools/lsp-server-support/Protocol.cpp
mlir/lib/Tools/lsp-server-support/Protocol.h
mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
mlir/test/mlir-pdll-lsp-server/initialize-params.test
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
index d0955ab62d8e7..ea3751eac3541 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
@@ -53,6 +53,11 @@ class Context {
/// with that name was inserted.
const Dialect *lookupDialect(StringRef name) const;
+ /// Return a range of all of the registered dialects.
+ /// `make_second_range` selects the mapped value of each entry in `dialects`
+ /// and `make_pointee_range` dereferences it, so the range yields the dialect
+ /// objects themselves, in the iteration order of `dialects`.
+ /// NOTE(review): `dialects` is declared outside this hunk; it presumably maps
+ /// dialect names to owning pointers -- confirm.
+ auto getDialects() const {
+ return llvm::make_pointee_range(llvm::make_second_range(dialects));
+ }
+
/// Insert a new operation with the context. Returns the inserted operation,
/// and a boolean indicating if the operation newly inserted (false if the
/// operation already existed).
diff --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
new file mode 100644
index 0000000000000..f97310b3ecbeb
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
@@ -0,0 +1,82 @@
+//===- CodeComplete.h - PDLL Frontend CodeComplete Context ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
+#define MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace pdll {
+namespace ast {
+class CallableDecl;
+class DeclScope;
+class Expr;
+class OperationType;
+class TupleType;
+class Type;
+class VariableDecl;
+} // namespace ast
+
+/// This class provides an abstract interface into the parser for hooking in
+/// code completion events. Every hook has an empty default implementation, so
+/// subclasses only need to override the completions they support.
+class CodeCompleteContext {
+public:
+ virtual ~CodeCompleteContext();
+
+ /// Return the location used to provide code completion.
+ SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; }
+
+ //===--------------------------------------------------------------------===//
+ // Completion Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Signal code completion for a member access into the given tuple type.
+ virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType);
+
+ /// Signal code completion for a member access into the given operation type.
+ virtual void codeCompleteOperationMemberAccess(ast::OperationType opType);
+
+ /// Signal code completion for an attribute name within the attribute list of
+ /// the operation named `opName`.
+ virtual void codeCompleteOperationAttributeName(StringRef opName) {}
+
+ /// Signal code completion for a constraint name with an optional decl scope.
+ /// `currentType` is the current type of the variable that will use the
+ /// constraint, or a null (default-constructed) type if the type is unknown.
+ /// `allowNonCoreConstraints` indicates if user defined constraints are
+ /// allowed in the completion results. `allowInlineTypeConstraints` enables
+ /// inline type constraints for Attr/Value/ValueRange. `scope`, if non-null,
+ /// is the declaration scope in effect at the completion point.
+ virtual void codeCompleteConstraintName(ast::Type currentType,
+ bool allowNonCoreConstraints,
+ bool allowInlineTypeConstraints,
+ const ast::DeclScope *scope);
+
+ /// Signal code completion for a dialect name.
+ virtual void codeCompleteDialectName() {}
+
+ /// Signal code completion for an operation name in the given dialect.
+ virtual void codeCompleteOperationName(StringRef dialectName) {}
+
+ /// Signal code completion for Pattern metadata.
+ virtual void codeCompletePatternMetadata() {}
+
+protected:
+ /// Create a new code completion context with the given code complete
+ /// location.
+ explicit CodeCompleteContext(SMLoc codeCompleteLoc)
+ : codeCompleteLoc(codeCompleteLoc) {}
+
+private:
+ /// The location used to code complete.
+ SMLoc codeCompleteLoc;
+};
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
diff --git a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
index a0b269aeb6eb1..ce5815a478d1d 100644
--- a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
+++ b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
@@ -19,14 +19,19 @@ class SourceMgr;
namespace mlir {
namespace pdll {
+class CodeCompleteContext;
+
namespace ast {
class Context;
class Module;
} // namespace ast
-/// Parse an AST module from the main file of the given source manager.
-FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
- llvm::SourceMgr &sourceMgr);
+/// Parse an AST module from the main file of the given source manager. An
+/// optional code completion context may be provided to receive code completion
+/// suggestions. If a completion is hit, this method returns a failure.
+FailureOr<ast::Module *>
+parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+ CodeCompleteContext *codeCompleteContext = nullptr);
} // namespace pdll
} // namespace mlir
diff --git a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
index f705214eaca60..7953677d1957e 100644
--- a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
@@ -1,5 +1,6 @@
llvm_add_library(MLIRPDLLParser STATIC
+ CodeComplete.cpp
Lexer.cpp
Parser.cpp
diff --git a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
new file mode 100644
index 0000000000000..acc2ca84037dc
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
@@ -0,0 +1,28 @@
+//===- CodeComplete.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
+
+using namespace mlir;
+using namespace mlir::pdll;
+
+//===----------------------------------------------------------------------===//
+// CodeCompleteContext
+//===----------------------------------------------------------------------===//
+
+// The destructor is the first virtual member and is not defined inline, so it
+// acts as the class's key function. Under the Itanium C++ ABI, the vtable is
+// therefore emitted in this translation unit.
+CodeCompleteContext::~CodeCompleteContext() = default;
+
+// Default hook implementations: each ignores its arguments and provides no
+// completions. They must be defined here rather than inline in the header
+// because their by-value parameter types (ast::TupleType, ast::OperationType,
+// ast::Type) are only forward-declared there. This file includes
+// "mlir/Tools/PDLL/AST/Types.h" for the complete definitions.
+void CodeCompleteContext::codeCompleteTupleMemberAccess(
+ ast::TupleType tupleType) {}
+void CodeCompleteContext::codeCompleteOperationMemberAccess(
+ ast::OperationType opType) {}
+
+void CodeCompleteContext::codeCompleteConstraintName(
+ ast::Type currentType, bool allowNonCoreConstraints,
+ bool allowInlineTypeConstraints, const ast::DeclScope *scope) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 61c5783e24545..efae2b46af8d0 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -9,6 +9,7 @@
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
@@ -67,12 +68,20 @@ std::string Token::getStringValue() const {
// Lexer
//===----------------------------------------------------------------------===//
-Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
- : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
+Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
+ CodeCompleteContext *codeCompleteContext)
+ : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
+ codeCompletionLocation(nullptr) {
curBufferID = mgr.getMainFileID();
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
curPtr = curBuffer.begin();
+ // Set the code completion location if necessary.
+ if (codeCompleteContext) {
+ codeCompletionLocation =
+ codeCompleteContext->getCodeCompleteLoc().getPointer();
+ }
+
// If the diag engine has no handler, add a default that emits to the
// SourceMgr.
if (!diagEngine.getHandlerFn()) {
@@ -147,6 +156,10 @@ Token Lexer::lexToken() {
while (true) {
const char *tokStart = curPtr;
+ // Check to see if this token is at the code completion location.
+ if (tokStart == codeCompletionLocation)
+ return formToken(Token::code_complete, tokStart);
+
// This always consumes at least one character.
int curChar = getNextChar();
switch (curChar) {
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 0109b0da36c79..07eb2e2e4c418 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -21,6 +21,8 @@ namespace mlir {
struct LogicalResult;
namespace pdll {
+class CodeCompleteContext;
+
namespace ast {
class DiagnosticEngine;
} // namespace ast
@@ -35,6 +37,7 @@ class Token {
// Markers.
eof,
error,
+ code_complete,
// Keywords.
KW_BEGIN,
@@ -162,7 +165,8 @@ class Token {
class Lexer {
public:
- Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
+ Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
+ CodeCompleteContext *codeCompleteContext);
~Lexer();
/// Return a reference to the source manager used by the lexer.
@@ -215,6 +219,9 @@ class Lexer {
/// A flag indicating if we added a default diagnostic handler to the provided
/// diagEngine.
bool addedHandlerToDiagEngine;
+
+ /// The optional code completion point within the input file.
+ const char *codeCompletionLocation;
};
} // namespace pdll
} // namespace mlir
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 167aee27f8747..3d1f2dc3ce6fe 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -21,6 +21,7 @@
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
@@ -41,13 +42,15 @@ using namespace mlir::pdll;
namespace {
class Parser {
public:
- Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
- : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
+ /// Create a parser for the main file of `sourceMgr`. `codeCompleteContext`
+ /// may be null. When set, it is handed to the lexer, which emits a
+ /// `code_complete` token at the context's completion location. It is also
+ /// stored so the parser can forward completion requests to it.
+ /// `curToken` is initialized by lexing the first token, which requires
+ /// `lexer` to be declared before `curToken` in this class.
+ Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+ CodeCompleteContext *codeCompleteContext)
+ : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
valueRangeTy(ast::ValueRangeType::get(ctx)),
typeTy(ast::TypeType::get(ctx)),
typeRangeTy(ast::TypeRangeType::get(ctx)),
- attrTy(ast::AttributeType::get(ctx)) {}
+ attrTy(ast::AttributeType::get(ctx)),
+ codeCompleteContext(codeCompleteContext) {}
/// Try to parse a new module. Returns nullptr in the case of failure.
FailureOr<ast::Module *> parseModule();
@@ -142,7 +145,8 @@ class Parser {
};
FailureOr<ast::Decl *> parseTopLevelDecl();
- FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+ FailureOr<ast::NamedAttributeDecl *>
+ parseNamedAttributeDecl(Optional<StringRef> parentOpName);
/// Parse an argument variable as part of the signature of a
/// UserConstraintDecl or UserRewriteDecl.
@@ -248,10 +252,13 @@ class Parser {
/// existing constraints that have already been parsed for the same entity
/// that will be constrained by this constraint. `allowInlineTypeConstraints`
/// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
+ /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
+ /// constraints) may be used with the variable.
FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
- bool allowInlineTypeConstraints);
+ bool allowInlineTypeConstraints,
+ bool allowNonCoreConstraints);
/// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
/// argument or result variable. The constraints for these variables do not
@@ -335,10 +342,12 @@ class Parser {
/// `inferredType` is the type of the variable inferred by the constraints
/// within the list, and is updated to the most refined type as determined by
/// the constraints. Returns success if the constraint list is valid, failure
- /// otherwise.
+ /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
+ /// defined constraints) may be used with the variable.
LogicalResult
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
- ast::Type &inferredType);
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints = true);
/// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variabled type and is refined within the type defined
/// by the constraint. Returns success if the constraint is valid, failure
@@ -399,6 +408,23 @@ class Parser {
createRewriteStmt(SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ /// The set of various code completion methods. Every completion method
+ /// returns `failure` to stop the parsing process after providing completion
+ /// results.
+
+ LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
+ LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
+ LogicalResult codeCompleteConstraintName(ast::Type inferredType,
+ bool allowNonCoreConstraints,
+ bool allowInlineTypeConstraints);
+ LogicalResult codeCompleteDialectName();
+ LogicalResult codeCompleteOperationName(StringRef dialectName);
+ LogicalResult codeCompletePatternMetadata();
+
//===--------------------------------------------------------------------===//
// Lexer Utilities
//===--------------------------------------------------------------------===//
@@ -481,6 +507,9 @@ class Parser {
/// A counter used when naming anonymous constraints and rewrites.
unsigned anonymousDeclNameCounter = 0;
+
+ /// The optional code completion context.
+ CodeCompleteContext *codeCompleteContext;
};
} // namespace
@@ -890,7 +919,12 @@ FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
return decl;
}
-FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
+FailureOr<ast::NamedAttributeDecl *>
+Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
+ // Check for name code completion.
+ if (curToken.is(Token::code_complete))
+ return codeCompleteAttributeName(parentOpName);
+
std::string attrNameStr;
if (curToken.isString())
attrNameStr = curToken.getStringValue();
@@ -1380,6 +1414,10 @@ Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
Optional<SMRange> hasBoundedRecursionLoc;
do {
+ // Handle metadata code completion.
+ if (curToken.is(Token::code_complete))
+ return codeCompletePatternMetadata();
+
if (curToken.isNot(Token::identifier))
return emitError("expected pattern metadata identifier");
StringRef metadataStr = curToken.getSpelling();
@@ -1488,7 +1526,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
Optional<SMRange> typeConstraint;
auto parseSingleConstraint = [&] {
FailureOr<ast::ConstraintRef> constraint = parseConstraint(
- typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
+ typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
+ /*allowNonCoreConstraints=*/true);
if (failed(constraint))
return failure();
constraints.push_back(*constraint);
@@ -1509,7 +1548,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
- bool allowInlineTypeConstraints) {
+ bool allowInlineTypeConstraints,
+ bool allowNonCoreConstraints) {
auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
if (!allowInlineTypeConstraints) {
return emitError(
@@ -1610,6 +1650,17 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
return emitErrorAndNote(
loc, "invalid reference to non-constraint", cstDecl->getLoc(),
"see the definition of `" + constraintName + "` here");
+ }
+ // Handle single entity constraint code completion.
+ case Token::code_complete: {
+ // Try to infer the current type for use by code completion.
+ ast::Type inferredType;
+ if (failed(validateVariableConstraints(existingConstraints, inferredType,
+ allowNonCoreConstraints)))
+ return failure();
+
+ return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
+ allowInlineTypeConstraints);
}
default:
break;
@@ -1618,9 +1669,13 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
}
+/// Parse the constraint of a UserConstraintDecl/UserRewriteDecl argument or
+/// result variable. Inline type constraints (e.g. `Value<valueType: Type>`)
+/// are rejected here. User defined (non-core) constraints are only allowed
+/// when `parserContext` is `ParserContext::Constraint`.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
+ // Constraint arguments may apply more complex constraints via the arguments.
+ bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
+
Optional<SMRange> typeConstraint;
return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
- /*allowInlineTypeConstraints=*/false);
+ /*allowInlineTypeConstraints=*/false,
+ allowNonCoreConstraints);
}
//===----------------------------------------------------------------------===//
@@ -1770,6 +1825,10 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
SMRange loc = curToken.getLoc();
consumeToken(Token::dot);
+ // Check for code completion of the member name.
+ if (curToken.is(Token::code_complete))
+ return codeCompleteMemberAccess(parentExpr);
+
// Parse the member name.
Token memberNameTok = curToken;
if (memberNameTok.isNot(Token::identifier, Token::integer) &&
@@ -1784,6 +1843,10 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
SMRange loc = curToken.getLoc();
+ // Check for code completion for the dialect name.
+ if (curToken.is(Token::code_complete))
+ return codeCompleteDialectName();
+
// Handle the case of an no operation name.
if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
if (allowEmptyName)
@@ -1797,6 +1860,10 @@ FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
return failure();
+ // Check for code completion for the operation name.
+ if (curToken.is(Token::code_complete))
+ return codeCompleteOperationName(name);
+
if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
return emitError("expected operation name after dialect namespace");
@@ -1843,6 +1910,7 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
parseWrappedOperationName(allowEmptyName);
if (failed(opNameDecl))
return failure();
+ Optional<StringRef> opName = (*opNameDecl)->getName();
// Functor used to create an implicit range variable, used for implicit "all"
// operand or results variables.
@@ -1882,7 +1950,8 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
SmallVector<ast::NamedAttributeDecl *> attributes;
if (consumeIf(Token::l_brace)) {
do {
- FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
+ FailureOr<ast::NamedAttributeDecl *> decl =
+ parseNamedAttributeDecl(opName);
if (failed(decl))
return failure();
attributes.emplace_back(*decl);
@@ -2362,9 +2431,11 @@ Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
+/// Validate every constraint in `constraints`, in order, passing the same
+/// `inferredType` to each so it can be refined. Returns failure as soon as one
+/// constraint fails validation, and success otherwise.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
- ast::Type &inferredType) {
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints) {
for (const ast::ConstraintRef &ref : constraints)
- if (failed(validateVariableConstraint(ref, inferredType)))
+ if (failed(validateVariableConstraint(ref, inferredType,
+ allowNonCoreConstraints)))
return failure();
return success();
}
@@ -2784,12 +2855,57 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
}
+//===----------------------------------------------------------------------===//
+// Code Completion
+//===----------------------------------------------------------------------===//
+
+// Each hook below is only called after the current token is `code_complete`.
+// The lexer only forms that token when a CodeCompleteContext was supplied, so
+// `codeCompleteContext` is non-null in all of them. Every hook returns failure
+// to stop parsing once completion results have been provided.
+
+LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
+ ast::Type parentType = parentExpr->getType();
+ // Only operation and tuple types have completable members; any other parent
+ // type produces no completions.
+ if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
+ codeCompleteContext->codeCompleteOperationMemberAccess(opType);
+ else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
+ codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
+ return failure();
+}
+
+LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
+ // Attribute names can only be suggested for a known operation name.
+ if (opName)
+ codeCompleteContext->codeCompleteOperationAttributeName(*opName);
+ return failure();
+}
+
+LogicalResult
+Parser::codeCompleteConstraintName(ast::Type inferredType,
+ bool allowNonCoreConstraints,
+ bool allowInlineTypeConstraints) {
+ codeCompleteContext->codeCompleteConstraintName(
+ inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
+ curDeclScope);
+ return failure();
+}
+
+LogicalResult Parser::codeCompleteDialectName() {
+ codeCompleteContext->codeCompleteDialectName();
+ return failure();
+}
+
+LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
+ codeCompleteContext->codeCompleteOperationName(dialectName);
+ return failure();
+}
+
+LogicalResult Parser::codeCompletePatternMetadata() {
+ codeCompleteContext->codeCompletePatternMetadata();
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
-FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
- llvm::SourceMgr &sourceMgr) {
- Parser parser(ctx, sourceMgr);
+// Parse the main file of `sourceMgr` into an AST module. When
+// `codeCompleteContext` is provided and parsing reaches its completion
+// location, the matching completion hook is invoked and this returns failure.
+FailureOr<ast::Module *>
+mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+ CodeCompleteContext *codeCompleteContext) {
+ Parser parser(ctx, sourceMgr, codeCompleteContext);
return parser.parseModule();
}
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
index 7bce042f8900c..2b9ba3b853f7f 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
@@ -585,3 +585,160 @@ llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams &params) {
{"version", params.version},
};
}
+
+//===----------------------------------------------------------------------===//
+// TextEdit
+//===----------------------------------------------------------------------===//
+
+// Parse a TextEdit. Both `range` and `newText` are required: the parse fails
+// if either is missing or malformed.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result,
+ llvm::json::Path path) {
+ llvm::json::ObjectMapper o(value, path);
+ return o && o.map("range", result.range) && o.map("newText", result.newText);
+}
+
+// Serialize a TextEdit as `{"range": ..., "newText": ...}`.
+llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) {
+ return llvm::json::Object{
+ {"range", value.range},
+ {"newText", value.newText},
+ };
+}
+
+// Debug printing in the form `<range> => "<newText>"`, with `newText`
+// escaped by printEscapedString.
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) {
+ os << value.range << " => \"";
+ llvm::printEscapedString(value.newText, os);
+ return os << '"';
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionItemKind
+//===----------------------------------------------------------------------===//
+
+// Parse a CompletionItemKind. Only integers in [Text, TypeParameter] are
+// accepted. `Missing` (0), out-of-range integers, and non-integer values are
+// rejected by returning false, without reporting an error through `path`.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+ CompletionItemKind &result, llvm::json::Path path) {
+ if (Optional<int64_t> intValue = value.getAsInteger()) {
+ if (*intValue < static_cast<int>(CompletionItemKind::Text) ||
+ *intValue > static_cast<int>(CompletionItemKind::TypeParameter))
+ return false;
+ result = static_cast<CompletionItemKind>(*intValue);
+ return true;
+ }
+ return false;
+}
+
+// Return `kind` if the client declared support for it in
+// `supportedCompletionItemKinds`; otherwise return a close substitute. The
+// fallbacks (File, Enum, Class, Text) all lie in the Text..Reference range,
+// which the LSP specification says every client supports even without a
+// declared `valueSet`.
+CompletionItemKind mlir::lsp::adjustKindToCapability(
+ CompletionItemKind kind,
+ CompletionItemKindBitset &supportedCompletionItemKinds) {
+ size_t kindVal = static_cast<size_t>(kind);
+ // `size()` is one past the largest valid bit index, so the upper bound must
+ // be strict. With `<=`, an out-of-range kind would index past the end of the
+ // bitset, which is undefined behavior.
+ if (kindVal >= kCompletionItemKindMin &&
+ kindVal < supportedCompletionItemKinds.size() &&
+ supportedCompletionItemKinds[kindVal])
+ return kind;
+
+ // Provide some fall backs for common kinds that are close enough.
+ switch (kind) {
+ case CompletionItemKind::Folder:
+ return CompletionItemKind::File;
+ case CompletionItemKind::EnumMember:
+ return CompletionItemKind::Enum;
+ case CompletionItemKind::Struct:
+ return CompletionItemKind::Class;
+ default:
+ return CompletionItemKind::Text;
+ }
+}
+
+// Parse a client's list of supported completion item kinds into a bitset
+// indexed by kind value. Entries that are not valid kinds are skipped rather
+// than failing the whole parse; only a non-array value is an error. Bits are
+// only ever set, so bits already set in `result` are preserved.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+ CompletionItemKindBitset &result,
+ llvm::json::Path path) {
+ if (const llvm::json::Array *arrayValue = value.getAsArray()) {
+ for (size_t i = 0, e = arrayValue->size(); i < e; ++i) {
+ CompletionItemKind kindOut;
+ if (fromJSON((*arrayValue)[i], kindOut, path.index(i)))
+ result.set(size_t(kindOut));
+ }
+ return true;
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionItem
+//===----------------------------------------------------------------------===//
+
+// Serialize a CompletionItem. `label` is mandatory (asserted). Every other
+// field is emitted only when set: non-empty strings and vectors, a kind or
+// insert-text format other than `Missing`, present `documentation` and
+// `textEdit`, and `deprecated` only when true.
+llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) {
+ assert(!value.label.empty() && "completion item label is required");
+ llvm::json::Object result{{"label", value.label}};
+ if (value.kind != CompletionItemKind::Missing)
+ result["kind"] = static_cast<int>(value.kind);
+ if (!value.detail.empty())
+ result["detail"] = value.detail;
+ if (value.documentation)
+ result["documentation"] = value.documentation;
+ if (!value.sortText.empty())
+ result["sortText"] = value.sortText;
+ if (!value.filterText.empty())
+ result["filterText"] = value.filterText;
+ if (!value.insertText.empty())
+ result["insertText"] = value.insertText;
+ if (value.insertTextFormat != InsertTextFormat::Missing)
+ result["insertTextFormat"] = static_cast<int>(value.insertTextFormat);
+ if (value.textEdit)
+ result["textEdit"] = *value.textEdit;
+ if (!value.additionalTextEdits.empty()) {
+ result["additionalTextEdits"] =
+ llvm::json::Array(value.additionalTextEdits);
+ }
+ if (value.deprecated)
+ result["deprecated"] = value.deprecated;
+ return std::move(result);
+}
+
+// Debug printing: the label followed by the item's JSON form.
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
+ const CompletionItem &value) {
+ return os << value.label << " - " << toJSON(value);
+}
+
+// Order items by `sortText`, falling back to `label` when `sortText` is empty.
+bool mlir::lsp::operator<(const CompletionItem &lhs,
+ const CompletionItem &rhs) {
+ return (lhs.sortText.empty() ? lhs.label : lhs.sortText) <
+ (rhs.sortText.empty() ? rhs.label : rhs.sortText);
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionList
+//===----------------------------------------------------------------------===//
+
+// Serialize a CompletionList as `{"isIncomplete": ..., "items": [...]}`.
+llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) {
+ return llvm::json::Object{
+ {"isIncomplete", value.isIncomplete},
+ {"items", llvm::json::Array(value.items)},
+ };
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionContext
+//===----------------------------------------------------------------------===//
+
+// Parse a CompletionContext. `triggerKind` is required; `triggerCharacter` is
+// optional and read through mapOptOrNull.
+// NOTE(review): `triggerKind` is cast to CompletionTriggerKind without a range
+// check, so out-of-range client values are accepted as-is.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+ CompletionContext &result, llvm::json::Path path) {
+ llvm::json::ObjectMapper o(value, path);
+ int triggerKind;
+ if (!o || !o.map("triggerKind", triggerKind) ||
+ !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path))
+ return false;
+ result.triggerKind = static_cast<CompletionTriggerKind>(triggerKind);
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionParams
+//===----------------------------------------------------------------------===//
+
+// Parse CompletionParams: the TextDocumentPositionParams fields plus an
+// optional `context`. When `context` is absent, `result.context` is left
+// unchanged.
+// NOTE(review): `getAsObject()` is dereferenced without a null check. This
+// presumably relies on the base-class parse above having already rejected
+// non-object values -- confirm.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+ CompletionParams &result, llvm::json::Path path) {
+ if (!fromJSON(value, static_cast<TextDocumentPositionParams &>(result), path))
+ return false;
+ if (const llvm::json::Value *context = value.getAsObject()->get("context"))
+ return fromJSON(*context, result.context, path.field("context"));
+ return true;
+}
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index 6cb1dc4d50f6c..cc8c52abb104e 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains structs based on the LSP specification at
-// https://github.com/Microsoft/language-server-protocol/blob/main/protocol.md
+// https://microsoft.github.io/language-server-protocol/specification
//
// This is not meant to be a complete implementation, new interfaces are added
// when they're needed.
@@ -666,6 +666,211 @@ struct PublishDiagnosticsParams {
/// Add support for JSON serialization.
llvm::json::Value toJSON(const PublishDiagnosticsParams &params);
+//===----------------------------------------------------------------------===//
+// TextEdit
+//===----------------------------------------------------------------------===//
+
+/// A single textual change to a document: replace `range` with `newText`.
+struct TextEdit {
+  /// The range of the text document to be manipulated. To insert text into a
+  /// document, use a range whose start equals its end.
+  Range range;
+
+  /// The replacement text. Use an empty string to delete the range.
+  std::string newText;
+};
+
+/// Two edits compare equal when they carry the same text and the same range.
+inline bool operator==(const TextEdit &lhs, const TextEdit &rhs) {
+  return lhs.newText == rhs.newText && lhs.range == rhs.range;
+}
+
+/// Add support for JSON (de)serialization and printing.
+bool fromJSON(const llvm::json::Value &value, TextEdit &result,
+              llvm::json::Path path);
+llvm::json::Value toJSON(const TextEdit &value);
+raw_ostream &operator<<(raw_ostream &os, const TextEdit &value);
+
+//===----------------------------------------------------------------------===//
+// CompletionItemKind
+//===----------------------------------------------------------------------===//
+
+/// The kind of a completion entry. The numeric values follow the LSP
+/// `CompletionItemKind` enumeration; `Missing` (0) is a local placeholder
+/// meaning "no kind has been set".
+enum class CompletionItemKind {
+  Missing = 0,
+  Text = 1,
+  Method = 2,
+  Function = 3,
+  Constructor = 4,
+  Field = 5,
+  Variable = 6,
+  Class = 7,
+  Interface = 8,
+  Module = 9,
+  Property = 10,
+  Unit = 11,
+  Value = 12,
+  Enum = 13,
+  Keyword = 14,
+  Snippet = 15,
+  Color = 16,
+  File = 17,
+  Reference = 18,
+  Folder = 19,
+  EnumMember = 20,
+  Constant = 21,
+  Struct = 22,
+  Event = 23,
+  Operator = 24,
+  TypeParameter = 25,
+};
+bool fromJSON(const llvm::json::Value &value, CompletionItemKind &result,
+              llvm::json::Path path);
+
+/// Smallest and largest protocol-defined kinds (`Missing` is excluded).
+constexpr size_t kCompletionItemKindMin =
+    static_cast<size_t>(CompletionItemKind::Text);
+constexpr size_t kCompletionItemKindMax =
+    static_cast<size_t>(CompletionItemKind::TypeParameter);
+/// One bit per completion item kind, indexed by the kind's numeric value.
+using CompletionItemKindBitset = std::bitset<kCompletionItemKindMax + 1>;
+bool fromJSON(const llvm::json::Value &value, CompletionItemKindBitset &result,
+              llvm::json::Path path);
+
+/// Map `kind` onto the kinds advertised in `supportedCompletionItemKinds`.
+/// NOTE(review): the fallback policy lives in the out-of-line definition,
+/// which is not visible from this header.
+CompletionItemKind
+adjustKindToCapability(CompletionItemKind kind,
+                       CompletionItemKindBitset &supportedCompletionItemKinds);
+
+//===----------------------------------------------------------------------===//
+// CompletionItem
+//===----------------------------------------------------------------------===//
+
+/// Defines whether the insert text in a completion item should be interpreted
+/// as plain text or a snippet. `Missing` (0) is not part of the protocol; it
+/// is the default used by `CompletionItem` when no format has been chosen.
+enum class InsertTextFormat {
+  Missing = 0,
+  /// The primary text to be inserted is treated as a plain string.
+  PlainText = 1,
+  /// The primary text to be inserted is treated as a snippet.
+  ///
+  /// A snippet can define tab stops and placeholders with `$1`, `$2`
+  /// and `${3:foo}`. `$0` defines the final tab stop, it defaults to the end
+  /// of the snippet. Placeholders with equal identifiers are linked, that is
+  /// typing in one will update others too.
+  ///
+  /// See also:
+  /// https://github.com/Microsoft/vscode/blob/master/src/vs/editor/contrib/snippet/common/snippet.md
+  Snippet = 2,
+};
+
+/// One candidate offered to the user in response to a completion request.
+/// String fields left empty are treated as unset.
+struct CompletionItem {
+  /// The text shown for this item. Unless `insertText` or `textEdit` says
+  /// otherwise, it is also the text inserted on selection.
+  std::string label;
+
+  /// The kind of this item; editors pick an icon based on it.
+  CompletionItemKind kind = CompletionItemKind::Missing;
+
+  /// Human-readable extra information about the item, such as a type or
+  /// symbol signature.
+  std::string detail;
+
+  /// Human-readable documentation for the item.
+  Optional<MarkupContent> documentation;
+
+  /// Key used when ordering this item against others; when empty, the label
+  /// is used.
+  std::string sortText;
+
+  /// Text matched against what the user typed; when empty, the label is used.
+  std::string filterText;
+
+  /// Text inserted into the document on selection; when empty, the label is
+  /// used.
+  std::string insertText;
+
+  /// How `insertText`, and the `newText` of any `textEdit`, are interpreted.
+  InsertTextFormat insertTextFormat = InsertTextFormat::Missing;
+
+  /// An edit applied to the document on selection. When present,
+  /// `insertText` is ignored.
+  ///
+  /// Note: the edit's range must lie on a single line and must contain the
+  /// position at which completion was requested.
+  Optional<TextEdit> textEdit;
+
+  /// Further edits applied on selection. They must not overlap the main edit
+  /// or each other.
+  std::vector<TextEdit> additionalTextEdits;
+
+  /// Whether the item should be presented as deprecated.
+  bool deprecated = false;
+};
+
+/// Add support for JSON serialization and printing, plus an ordering between
+/// completion items (its criteria are defined out of line).
+llvm::json::Value toJSON(const CompletionItem &value);
+raw_ostream &operator<<(raw_ostream &os, const CompletionItem &value);
+bool operator<(const CompletionItem &lhs, const CompletionItem &rhs);
+
+//===----------------------------------------------------------------------===//
+// CompletionList
+//===----------------------------------------------------------------------===//
+
+/// The payload returned for a completion request: a set of items to show in
+/// the editor.
+struct CompletionList {
+  /// When true, the list is partial and further typing should trigger a new
+  /// request rather than client-side filtering.
+  bool isIncomplete = false;
+
+  /// The completion candidates.
+  std::vector<CompletionItem> items;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const CompletionList &value);
+
+//===----------------------------------------------------------------------===//
+// CompletionContext
+//===----------------------------------------------------------------------===//
+
+/// How a completion request was triggered.
+enum class CompletionTriggerKind {
+  /// Completion was triggered by typing an identifier (24x7 code
+  /// complete), manual invocation (e.g Ctrl+Space) or via API.
+  Invoked = 1,
+
+  /// Completion was triggered by a trigger character specified by
+  /// the `triggerCharacters` properties of the `CompletionRegistrationOptions`.
+  TriggerCharacter = 2,
+
+  /// Completion was re-triggered as the current completion list is incomplete.
+  /// This spelling matches the LSP specification.
+  TriggerForIncompleteCompletions = 3,
+
+  /// Legacy spelling of `TriggerForIncompleteCompletions` (same value), kept
+  /// so that existing users keep compiling.
+  TriggerTriggerForIncompleteCompletions = TriggerForIncompleteCompletions
+};
+
+struct CompletionContext {
+  /// How the completion was triggered.
+  CompletionTriggerKind triggerKind = CompletionTriggerKind::Invoked;
+
+  /// The character (a single character) that triggered code completion. Only
+  /// meaningful when `triggerKind` is `CompletionTriggerKind::TriggerCharacter`;
+  /// otherwise it is expected to be empty.
+  std::string triggerCharacter;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, CompletionContext &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// CompletionParams
+//===----------------------------------------------------------------------===//
+
+/// Parameters of a `textDocument/completion` request: the document position
+/// being completed, plus optional context about how the request arose.
+struct CompletionParams : TextDocumentPositionParams {
+  /// Trigger information; keeps its defaults when the client sends none.
+  CompletionContext context;
+};
+
+/// Add support for JSON deserialization.
+bool fromJSON(const llvm::json::Value &value, CompletionParams &result,
+              llvm::json::Path path);
+
} // namespace lsp
} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
index 214e346ed0128..57280e0bdd171 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -64,6 +64,12 @@ struct LSPServer {
void onDocumentSymbol(const DocumentSymbolParams &params,
Callback<std::vector<DocumentSymbol>> reply);
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+
+  /// Handle a `textDocument/completion` request at the given document
+  /// position.
+  void onCompletion(const CompletionParams &params,
+                    Callback<CompletionList> reply);
+
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
@@ -94,6 +100,15 @@ void LSPServer::onInitialize(const InitializeParams &params,
{"change", (int)TextDocumentSyncKind::Full},
{"save", true},
}},
+ {"completionProvider",
+ llvm::json::Object{
+ {"allCommitCharacters",
+ {" ", "\t", "(", ")", "[", "]", "{", "}", "<",
+ ">", ":", ";", ",", "+", "-", "/", "*", "%",
+ "^", "&", "#", "?", ".", "=", "\"", "'", "|"}},
+ {"resolveProvider", false},
+ {"triggerCharacters", {".", ">", "(", "{", ",", "<", ":", "[", " "}},
+ }},
{"definitionProvider", true},
{"referencesProvider", true},
{"hoverProvider", true},
@@ -186,6 +201,14 @@ void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
reply(std::move(symbols));
}
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+/// Answer a completion request by asking the PDLL server for the completions
+/// at the requested document position and replying with the result.
+void LSPServer::onCompletion(const CompletionParams &params,
+                             Callback<CompletionList> reply) {
+  reply(server.getCodeCompletion(params.textDocument.uri, params.position));
+}
+
//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//
@@ -222,6 +245,10 @@ LogicalResult mlir::lsp::runPdllLSPServer(PDLLServer &server,
messageHandler.method("textDocument/documentSymbol", &lspServer,
&LSPServer::onDocumentSymbol);
+ // Code Completion
+ messageHandler.method("textDocument/completion", &lspServer,
+ &LSPServer::onCompletion);
+
// Diagnostics
lspServer.publishDiagnostics =
messageHandler.outgoingNotification<PublishDiagnosticsParams>(
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index cc8c22f49ab23..495df9019a174 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -17,6 +17,7 @@
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Dialect.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/StringMap.h"
@@ -277,6 +278,13 @@ struct PDLDocument {
void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+ //===--------------------------------------------------------------------===//
+ // Code Completion
+ //===--------------------------------------------------------------------===//
+
+ lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+ const lsp::Position &completePos);
+
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
@@ -546,6 +554,279 @@ void PDLDocument::findDocumentSymbols(
}
}
+//===----------------------------------------------------------------------===//
+// PDLDocument: Code Completion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Bridges the PDLL parser's code completion callbacks to LSP. Each
+/// `codeComplete*` hook appends `lsp::CompletionItem`s for the candidates
+/// valid at the completion location to `completionList`. ODS knowledge
+/// (dialects, operations, attributes, results) is looked up in `odsContext`.
+class LSPCodeCompleteContext : public CodeCompleteContext {
+public:
+  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
+                         ods::Context &odsContext)
+      : CodeCompleteContext(completeLoc), completionList(completionList),
+        odsContext(odsContext) {}
+
+  /// Offer the elements of `tupleType`. Every element is offered by index,
+  /// and named elements are also offered by name.
+  void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
+    ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
+    ArrayRef<StringRef> elementNames = tupleType.getElementNames();
+    for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
+      // Push back a completion item that uses the element index.
+      lsp::CompletionItem item;
+      item.label = llvm::formatv("{0} (field #{0})", i).str();
+      item.insertText = Twine(i).str();
+      item.filterText = item.sortText = item.insertText;
+      item.kind = lsp::CompletionItemKind::Field;
+      item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]).str();
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(item);
+
+      // If the element has a name, also offer it by name. The kind, detail
+      // and sort text are shared with the index-based item above.
+      if (!elementNames[i].empty()) {
+        item.label =
+            llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
+        item.filterText = item.label;
+        item.insertText = elementNames[i].str();
+        completionList.items.emplace_back(std::move(item));
+      }
+    }
+  }
+
+  /// Offer the results of the ODS operation named by `opType`. Each result
+  /// is offered by index and, when it has a name, also by name. Nothing is
+  /// offered when the operation is unnamed or unknown to the ODS context.
+  void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
+    Optional<StringRef> opName = opType.getName();
+    const ods::Operation *odsOp =
+        opName ? odsContext.lookupOperation(*opName) : nullptr;
+    if (!odsOp)
+      return;
+
+    ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
+    for (const auto &it : llvm::enumerate(results)) {
+      const ods::OperandOrResult &result = it.value();
+      const ods::TypeConstraint &constraint = result.getConstraint();
+
+      // Push back a completion item that uses the result index.
+      lsp::CompletionItem item;
+      item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
+      item.insertText = Twine(it.index()).str();
+      item.filterText = item.sortText = item.insertText;
+      item.kind = lsp::CompletionItemKind::Field;
+      switch (result.getVariableLengthKind()) {
+      case ods::VariableLengthKind::Single:
+        item.detail = llvm::formatv("{0}: Value", it.index()).str();
+        break;
+      case ods::VariableLengthKind::Optional:
+        item.detail = llvm::formatv("{0}: Value?", it.index()).str();
+        break;
+      case ods::VariableLengthKind::Variadic:
+        item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
+        break;
+      }
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
+                        constraint.getCppClass())
+              .str()};
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(item);
+
+      // If the result has a name, also offer it by name. The remaining fields
+      // are shared with the index-based item above.
+      if (!result.getName().empty()) {
+        item.label =
+            llvm::formatv("{1} (field #{0})", it.index(), result.getName())
+                .str();
+        item.filterText = item.label;
+        item.insertText = result.getName().str();
+        completionList.items.emplace_back(std::move(item));
+      }
+    }
+  }
+
+  /// Offer the attribute names declared by the ODS operation `opName`, if it
+  /// is known.
+  void codeCompleteOperationAttributeName(StringRef opName) final {
+    const ods::Operation *odsOp = odsContext.lookupOperation(opName);
+    if (!odsOp)
+      return;
+
+    for (const ods::Attribute &attr : odsOp->getAttributes()) {
+      const ods::AttributeConstraint &constraint = attr.getConstraint();
+
+      lsp::CompletionItem item;
+      item.label = attr.getName().str();
+      item.kind = lsp::CompletionItemKind::Field;
+      item.detail = attr.isOptional() ? "optional" : "";
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
+                        constraint.getCppClass())
+              .str()};
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Offer constraint names.
+  /// - Core constraints are offered when no type has been inferred yet.
+  /// - Inline-typed core constraints are offered when allowed and compatible
+  ///   with `currentType`.
+  /// - When `allowNonCoreConstraints` is set, single-argument user constraints
+  ///   visible from `scope` (walking up through parent scopes) whose input type
+  ///   is compatible with `currentType` are offered too.
+  void codeCompleteConstraintName(ast::Type currentType,
+                                  bool allowNonCoreConstraints,
+                                  bool allowInlineTypeConstraints,
+                                  const ast::DeclScope *scope) final {
+    auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
+                                 StringRef snippetText = "") {
+      lsp::CompletionItem item;
+      item.label = constraint.str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.detail = (constraint + " constraint").str();
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          ("A single entity core constraint of type `" + mlirType + "`").str()};
+      item.sortText = "0";
+      item.insertText = snippetText.str();
+      item.insertTextFormat = snippetText.empty()
+                                  ? lsp::InsertTextFormat::PlainText
+                                  : lsp::InsertTextFormat::Snippet;
+      completionList.items.emplace_back(std::move(item));
+    };
+
+    // Insert completions for the core constraints. Some core constraints have
+    // additional characteristics, so we may add them even if a type has been
+    // inferred.
+    if (!currentType) {
+      addCoreConstraint("Attr", "mlir::Attribute");
+      addCoreConstraint("Op", "mlir::Operation *");
+      addCoreConstraint("Value", "mlir::Value");
+      addCoreConstraint("ValueRange", "mlir::ValueRange");
+      addCoreConstraint("Type", "mlir::Type");
+      addCoreConstraint("TypeRange", "mlir::TypeRange");
+    }
+    if (allowInlineTypeConstraints) {
+      // Attr<Type>.
+      if (!currentType || currentType.isa<ast::AttributeType>())
+        addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
+      // Value<Type>.
+      if (!currentType || currentType.isa<ast::ValueType>())
+        addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
+      // ValueRange<TypeRange>.
+      if (!currentType || currentType.isa<ast::ValueRangeType>())
+        addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
+                          "ValueRange<$1>");
+    }
+
+    // Only user-defined constraints remain; they are not allowed here, so the
+    // scope walk below would never add anything.
+    if (!allowNonCoreConstraints)
+      return;
+
+    // If a scope was provided, check it (and its parents) for user constraints.
+    for (; scope; scope = scope->getParentScope()) {
+      for (const ast::Decl *decl : scope->getDecls()) {
+        const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl);
+        if (!cst)
+          continue;
+
+        // Skip constraints that are not single-arg. We currently only
+        // complete variable constraints.
+        if (cst->getInputs().size() != 1)
+          continue;
+
+        // Ensure the input type matches the given type.
+        ast::Type constraintType = cst->getInputs()[0]->getType();
+        if (currentType && !currentType.refineWith(constraintType))
+          continue;
+
+        lsp::CompletionItem item;
+        item.label = cst->getName().getName().str();
+        item.kind = lsp::CompletionItemKind::Interface;
+        item.sortText = "2_" + item.label;
+
+        // Format the constraint signature. The stream is scoped so it is
+        // flushed into `item.detail` before the item is stored.
+        {
+          llvm::raw_string_ostream strOS(item.detail);
+          strOS << "(";
+          llvm::interleaveComma(
+              cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
+                strOS << var->getName().getName() << ": " << var->getType();
+              });
+          strOS << ") -> " << cst->getResultType();
+        }
+
+        completionList.items.emplace_back(std::move(item));
+      }
+    }
+  }
+
+  /// Offer the names of all dialects registered in the ODS context.
+  void codeCompleteDialectName() final {
+    for (const ods::Dialect &dialect : odsContext.getDialects()) {
+      lsp::CompletionItem item;
+      item.label = dialect.getName().str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Offer the operations of `dialectName`, labelled without the
+  /// `<dialect>.` prefix.
+  void codeCompleteOperationName(StringRef dialectName) final {
+    const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
+    if (!dialect)
+      return;
+
+    for (const auto &it : dialect->getOperations()) {
+      const ods::Operation &op = *it.second;
+
+      lsp::CompletionItem item;
+      // Drop the dialect name and the '.' separator.
+      item.label = op.getName().drop_front(dialectName.size() + 1).str();
+      item.kind = lsp::CompletionItemKind::Field;
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Offer the pattern metadata keywords (`benefit`, `recursion`).
+  void codeCompletePatternMetadata() final {
+    auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
+                                   StringRef snippetText = "") {
+      lsp::CompletionItem item;
+      item.label = constraint.str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.detail = "pattern metadata";
+      item.documentation =
+          lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
+      item.insertText = snippetText.str();
+      item.insertTextFormat = snippetText.empty()
+                                  ? lsp::InsertTextFormat::PlainText
+                                  : lsp::InsertTextFormat::Snippet;
+      completionList.items.emplace_back(std::move(item));
+    };
+
+    addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
+                        "benefit($1)");
+    addSimpleConstraint("recursion",
+                        "The pattern properly handles recursive application.");
+  }
+
+private:
+  /// The list that completion items are appended to.
+  lsp::CompletionList &completionList;
+  /// ODS context used to look up dialects, operations and their members.
+  ods::Context &odsContext;
+};
+} // namespace
+
+/// Compute the completion candidates at `completePos`. The document is
+/// re-parsed with scratch ODS/AST contexts while an `LSPCodeCompleteContext`
+/// collects candidates.
+lsp::CompletionList
+PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
+                               const lsp::Position &completePos) {
+  // Positions that do not map into the document yield no completions.
+  SMLoc cursorLoc = completePos.getAsSMLoc(sourceMgr);
+  if (!cursorLoc.isValid())
+    return lsp::CompletionList();
+
+  // The parser is given the location one character past the cursor, i.e.
+  // after the token that triggered completion.
+  SMLoc completeLoc = SMLoc::getFromPointer(cursorLoc.getPointer() + 1);
+
+  lsp::CompletionList completionList;
+  ods::Context scratchODSContext;
+  LSPCodeCompleteContext completeContext(completeLoc, completionList,
+                                         scratchODSContext);
+  ast::Context scratchASTContext(scratchODSContext);
+  (void)parsePDLAST(scratchASTContext, sourceMgr, &completeContext);
+  return completionList;
+}
+
//===----------------------------------------------------------------------===//
// PDLTextFileChunk
//===----------------------------------------------------------------------===//
@@ -600,6 +881,8 @@ class PDLTextFile {
Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
lsp::Position hoverPos);
void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+ lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+ lsp::Position completePos);
private:
/// Find the PDL document that contains the given position, and update the
@@ -737,6 +1020,22 @@ void PDLTextFile::findDocumentSymbols(
}
}
+/// Compute completions by delegating to the chunk containing `completePos`,
+/// then adjusting any edit ranges in the result for that chunk's offset.
+lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
+                                                   lsp::Position completePos) {
+  // `getChunkFor` takes the position by reference; the possibly-updated
+  // position is what the chunk's document is queried with.
+  PDLTextFileChunk &chunk = getChunkFor(completePos);
+  lsp::CompletionList result =
+      chunk.document.getCodeCompletion(uri, completePos);
+
+  auto rebase = [&](lsp::Range &range) {
+    chunk.adjustLocForChunkOffset(range);
+  };
+  for (lsp::CompletionItem &item : result.items) {
+    if (item.textEdit)
+      rebase(item.textEdit->range);
+    for (lsp::TextEdit &extraEdit : item.additionalTextEdits)
+      rebase(extraEdit.range);
+  }
+  return result;
+}
+
PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) {
if (chunks.size() == 1)
return *chunks.front();
@@ -815,3 +1114,12 @@ void lsp::PDLLServer::findDocumentSymbols(
if (fileIt != impl->files.end())
fileIt->second->findDocumentSymbols(symbols);
}
+
+/// Return the completions at `completePos` in the file named by `uri`, or an
+/// empty list when the file is not tracked by the server.
+lsp::CompletionList
+lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt == impl->files.end())
+    return CompletionList();
+  return fileIt->second->getCodeCompletion(uri, completePos);
+}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
index 1a647f18db125..5f593204bab24 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
@@ -15,6 +15,7 @@
namespace mlir {
namespace lsp {
struct Diagnostic;
+struct CompletionList;
struct DocumentSymbol;
struct Hover;
struct Location;
@@ -57,6 +58,10 @@ class PDLLServer {
void findDocumentSymbols(const URIForFile &uri,
std::vector<DocumentSymbol> &symbols);
+ /// Get the code completion list for the position within the given file.
+ CompletionList getCodeCompletion(const URIForFile &uri,
+ const Position &completePos);
+
private:
struct Impl;
diff --git a/mlir/test/mlir-pdll-lsp-server/completion.test b/mlir/test/mlir-pdll-lsp-server/completion.test
new file mode 100644
index 0000000000000..a8d80d8f5ff59
--- /dev/null
+++ b/mlir/test/mlir-pdll-lsp-server/completion.test
@@ -0,0 +1,205 @@
+// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+ "uri":"test:///foo.pdll",
+ "languageId":"pdll",
+ "version":1,
+ "text":"Constraint ValueCst(value: Value);\nConstraint Cst();\nPattern FooPattern with benefit(1) {\nlet tuple = (value1 = _: Op, _: Op<test.op>);\nerase tuple.value1;\n}"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.pdll"},
+ "position":{"line":4,"character":12}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "0: Op",
+// CHECK-NEXT: "filterText": "0",
+// CHECK-NEXT: "insertText": "0",
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 5,
+// CHECK-NEXT: "label": "0 (field #0)",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "0: Op",
+// CHECK-NEXT: "filterText": "value1 (field #0)",
+// CHECK-NEXT: "insertText": "value1",
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 5,
+// CHECK-NEXT: "label": "value1 (field #0)",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "1: Op<test.op>",
+// CHECK-NEXT: "filterText": "1",
+// CHECK-NEXT: "insertText": "1",
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 5,
+// CHECK-NEXT: "label": "1 (field #1)",
+// CHECK-NEXT: "sortText": "1"
+// CHECK-NEXT: }
+// CHECK-NEXT: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.pdll"},
+ "position":{"line":2,"character":23}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "pattern metadata",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "The `benefit` of matching the pattern."
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertText": "benefit($1)",
+// CHECK-NEXT: "insertTextFormat": 2,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "benefit"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "pattern metadata",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "The pattern properly handles recursive application."
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "recursion"
+// CHECK-NEXT: }
+// CHECK-NEXT: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+ "textDocument":{"uri":"test:///foo.pdll"},
+ "position":{"line":3,"character":24}
+}}
+// CHECK: "id": 1
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": {
+// CHECK-NEXT: "isIncomplete": false,
+// CHECK-NEXT: "items": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Attr constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Attribute`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Attr",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Op constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Operation *`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Op",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Value constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Value`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Value",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "ValueRange constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::ValueRange`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "ValueRange",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Type constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Type`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Type",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "TypeRange constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::TypeRange`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertTextFormat": 1,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "TypeRange",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Attr<type> constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Attribute`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertText": "Attr<$1>",
+// CHECK-NEXT: "insertTextFormat": 2,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Attr<type>",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "Value<type> constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Value`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertText": "Value<$1>",
+// CHECK-NEXT: "insertTextFormat": 2,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "Value<type>",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "ValueRange<type> constraint",
+// CHECK-NEXT: "documentation": {
+// CHECK-NEXT: "kind": "markdown",
+// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::ValueRange`"
+// CHECK-NEXT: },
+// CHECK-NEXT: "insertText": "ValueRange<$1>",
+// CHECK-NEXT: "insertTextFormat": 2,
+// CHECK-NEXT: "kind": 7,
+// CHECK-NEXT: "label": "ValueRange<type>",
+// CHECK-NEXT: "sortText": "0"
+// CHECK-NEXT: },
+// CHECK-NEXT: {
+// CHECK-NEXT: "detail": "(value: Value) -> Tuple<>",
+// CHECK-NEXT: "kind": 8,
+// CHECK-NEXT: "label": "ValueCst",
+// CHECK-NEXT: "sortText": "2_ValueCst"
+// CHECK-NEXT: }
+// CHECK-NEXT: ]
+// CHECK-NEXT: }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-pdll-lsp-server/initialize-params.test b/mlir/test/mlir-pdll-lsp-server/initialize-params.test
index d2af6f514fe54..2d40edfaa7bc5 100644
--- a/mlir/test/mlir-pdll-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-pdll-lsp-server/initialize-params.test
@@ -5,6 +5,13 @@
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "capabilities": {
+// CHECK-NEXT: "completionProvider": {
+// CHECK-NEXT: "allCommitCharacters": [
+// CHECK: ],
+// CHECK-NEXT: "resolveProvider": false,
+// CHECK-NEXT: "triggerCharacters": [
+// CHECK: ]
+// CHECK-NEXT: },
// CHECK-NEXT: "definitionProvider": true,
// CHECK-NEXT: "documentSymbolProvider": true,
// CHECK-NEXT: "hoverProvider": true,
More information about the Mlir-commits
mailing list