[Mlir-commits] [mlir] 008de48 - [mlir][PDLL] Add code completion to the PDLL language server

River Riddle llvmlistbot at llvm.org
Sat Mar 19 13:29:34 PDT 2022


Author: River Riddle
Date: 2022-03-19T13:28:24-07:00
New Revision: 008de486f706ef25a66d4384c2c3af1ed86e680e

URL: https://github.com/llvm/llvm-project/commit/008de486f706ef25a66d4384c2c3af1ed86e680e
DIFF: https://github.com/llvm/llvm-project/commit/008de486f706ef25a66d4384c2c3af1ed86e680e.diff

LOG: [mlir][PDLL] Add code completion to the PDLL language server

This commit adds code completion support to the language server.
It initially provides completions for member access,
attribute/constraint/dialect/operation names, and pattern metadata.

Differential Revision: https://reviews.llvm.org/D121544

Added: 
    mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
    mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
    mlir/test/mlir-pdll-lsp-server/completion.test

Modified: 
    mlir/include/mlir/Tools/PDLL/ODS/Context.h
    mlir/include/mlir/Tools/PDLL/Parser/Parser.h
    mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
    mlir/lib/Tools/PDLL/Parser/Lexer.cpp
    mlir/lib/Tools/PDLL/Parser/Lexer.h
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/lib/Tools/lsp-server-support/Protocol.cpp
    mlir/lib/Tools/lsp-server-support/Protocol.h
    mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
    mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
    mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
    mlir/test/mlir-pdll-lsp-server/initialize-params.test

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
index d0955ab62d8e7..ea3751eac3541 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
@@ -53,6 +53,11 @@ class Context {
   /// with that name was inserted.
   const Dialect *lookupDialect(StringRef name) const;
 
+  /// Return a range over all registered dialects, yielding each by reference.
+  auto getDialects() const {
+    return llvm::make_pointee_range(llvm::make_second_range(dialects));
+  }
+
   /// Insert a new operation with the context. Returns the inserted operation,
   /// and a boolean indicating if the operation newly inserted (false if the
   /// operation already existed).

diff  --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
new file mode 100644
index 0000000000000..f97310b3ecbeb
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
@@ -0,0 +1,82 @@
+//===- CodeComplete.h - PDLL Frontend CodeComplete Context ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
+#define MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace pdll {
+namespace ast {
+class CallableDecl;
+class DeclScope;
+class Expr;
+class OperationType;
+class TupleType;
+class Type;
+class VariableDecl;
+} // namespace ast
+
+/// This class provides an abstract interface into the parser for hooking in
+/// code completion events.
+class CodeCompleteContext {
+public:
+  virtual ~CodeCompleteContext();
+
+  /// Return the location used to provide code completion.
+  SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; }
+
+  //===--------------------------------------------------------------------===//
+  // Completion Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Signal code completion for a member access into the given tuple type.
+  virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType);
+
+  /// Signal code completion for a member access into the given operation type.
+  virtual void codeCompleteOperationMemberAccess(ast::OperationType opType);
+
+  /// Signal code completion for an attribute name of the operation `opName`.
+  virtual void codeCompleteOperationAttributeName(StringRef opName) {}
+
+  /// Signal code completion for a constraint name with an optional decl scope.
+  /// `currentType` is the current type of the variable that will use the
+  /// constraint, or nullptr if a type is unknown. `allowNonCoreConstraints`
+  /// indicates if user defined constraints are allowed in the completion
+  /// results. `allowInlineTypeConstraints` enables inline type constraints for
+  /// Attr/Value/ValueRange.
+  virtual void codeCompleteConstraintName(ast::Type currentType,
+                                          bool allowNonCoreConstraints,
+                                          bool allowInlineTypeConstraints,
+                                          const ast::DeclScope *scope);
+
+  /// Signal code completion for a dialect name.
+  virtual void codeCompleteDialectName() {}
+
+  /// Signal code completion for an operation name in the given dialect.
+  virtual void codeCompleteOperationName(StringRef dialectName) {}
+
+  /// Signal code completion for Pattern metadata.
+  virtual void codeCompletePatternMetadata() {}
+
+protected:
+  /// Create a new code completion context with the given code complete
+  /// location.
+  explicit CodeCompleteContext(SMLoc codeCompleteLoc)
+      : codeCompleteLoc(codeCompleteLoc) {}
+
+private:
+  /// The location used to code complete.
+  SMLoc codeCompleteLoc;
+};
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_

diff  --git a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
index a0b269aeb6eb1..ce5815a478d1d 100644
--- a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
+++ b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
@@ -19,14 +19,19 @@ class SourceMgr;
 
 namespace mlir {
 namespace pdll {
+class CodeCompleteContext;
+
 namespace ast {
 class Context;
 class Module;
 } // namespace ast
 
-/// Parse an AST module from the main file of the given source manager.
-FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
-                                     llvm::SourceMgr &sourceMgr);
+/// Parse an AST module from the main file of the given source manager. An
+/// optional code completion context may be provided to receive code completion
+/// suggestions. If a completion is hit, this method returns a failure.
+FailureOr<ast::Module *>
+parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+            CodeCompleteContext *codeCompleteContext = nullptr);
 } // namespace pdll
 } // namespace mlir
 

diff  --git a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
index f705214eaca60..7953677d1957e 100644
--- a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
@@ -1,5 +1,6 @@
 
 llvm_add_library(MLIRPDLLParser STATIC
+  CodeComplete.cpp
   Lexer.cpp
   Parser.cpp
   

diff  --git a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
new file mode 100644
index 0000000000000..acc2ca84037dc
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
@@ -0,0 +1,28 @@
+//===- CodeComplete.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
+
+using namespace mlir;
+using namespace mlir::pdll;
+
+//===----------------------------------------------------------------------===//
+// CodeCompleteContext
+//===----------------------------------------------------------------------===//
+
+CodeCompleteContext::~CodeCompleteContext() = default;
+// No-op hook defaults; the header only forward declares their AST types.
+void CodeCompleteContext::codeCompleteTupleMemberAccess(
+    ast::TupleType tupleType) {}
+void CodeCompleteContext::codeCompleteOperationMemberAccess(
+    ast::OperationType opType) {}
+
+void CodeCompleteContext::codeCompleteConstraintName(
+    ast::Type currentType, bool allowNonCoreConstraints,
+    bool allowInlineTypeConstraints, const ast::DeclScope *scope) {}

diff  --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 61c5783e24545..efae2b46af8d0 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -9,6 +9,7 @@
 #include "Lexer.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/SourceMgr.h"
@@ -67,12 +68,20 @@ std::string Token::getStringValue() const {
 // Lexer
 //===----------------------------------------------------------------------===//
 
-Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
-    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
+Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
+             CodeCompleteContext *codeCompleteContext)
+    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
+      codeCompletionLocation(nullptr) {
   curBufferID = mgr.getMainFileID();
   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
   curPtr = curBuffer.begin();
 
+  // Set the code completion location if necessary.
+  if (codeCompleteContext) {
+    codeCompletionLocation =
+        codeCompleteContext->getCodeCompleteLoc().getPointer();
+  }
+
   // If the diag engine has no handler, add a default that emits to the
   // SourceMgr.
   if (!diagEngine.getHandlerFn()) {
@@ -147,6 +156,10 @@ Token Lexer::lexToken() {
   while (true) {
     const char *tokStart = curPtr;
 
+    // Check to see if this token is at the code completion location.
+    if (tokStart == codeCompletionLocation)
+      return formToken(Token::code_complete, tokStart);
+
     // This always consumes at least one character.
     int curChar = getNextChar();
     switch (curChar) {

diff  --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 0109b0da36c79..07eb2e2e4c418 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -21,6 +21,8 @@ namespace mlir {
 struct LogicalResult;
 
 namespace pdll {
+class CodeCompleteContext;
+
 namespace ast {
 class DiagnosticEngine;
 } // namespace ast
@@ -35,6 +37,7 @@ class Token {
     // Markers.
     eof,
     error,
+    code_complete,
 
     // Keywords.
     KW_BEGIN,
@@ -162,7 +165,8 @@ class Token {
 
 class Lexer {
 public:
-  Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
+  Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
+        CodeCompleteContext *codeCompleteContext);
   ~Lexer();
 
   /// Return a reference to the source manager used by the lexer.
@@ -215,6 +219,9 @@ class Lexer {
   /// A flag indicating if we added a default diagnostic handler to the provided
   /// diagEngine.
   bool addedHandlerToDiagEngine;
+
+  /// The optional code completion point within the input file.
+  const char *codeCompletionLocation;
 };
 } // namespace pdll
 } // namespace mlir

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 167aee27f8747..3d1f2dc3ce6fe 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Tools/PDLL/ODS/Constraint.h"
 #include "mlir/Tools/PDLL/ODS/Context.h"
 #include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -41,13 +42,15 @@ using namespace mlir::pdll;
 namespace {
 class Parser {
 public:
-  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
-      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
+  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+         CodeCompleteContext *codeCompleteContext)
+      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
         curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
         valueRangeTy(ast::ValueRangeType::get(ctx)),
         typeTy(ast::TypeType::get(ctx)),
         typeRangeTy(ast::TypeRangeType::get(ctx)),
-        attrTy(ast::AttributeType::get(ctx)) {}
+        attrTy(ast::AttributeType::get(ctx)),
+        codeCompleteContext(codeCompleteContext) {}
 
   /// Try to parse a new module. Returns nullptr in the case of failure.
   FailureOr<ast::Module *> parseModule();
@@ -142,7 +145,8 @@ class Parser {
   };
 
   FailureOr<ast::Decl *> parseTopLevelDecl();
-  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+  FailureOr<ast::NamedAttributeDecl *>
+  parseNamedAttributeDecl(Optional<StringRef> parentOpName);
 
   /// Parse an argument variable as part of the signature of a
   /// UserConstraintDecl or UserRewriteDecl.
@@ -248,10 +252,13 @@ class Parser {
   /// existing constraints that have already been parsed for the same entity
   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
+  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
+  /// constraints) may be used with the variable.
   FailureOr<ast::ConstraintRef>
   parseConstraint(Optional<SMRange> &typeConstraint,
                   ArrayRef<ast::ConstraintRef> existingConstraints,
-                  bool allowInlineTypeConstraints);
+                  bool allowInlineTypeConstraints,
+                  bool allowNonCoreConstraints);
 
   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
   /// argument or result variable. The constraints for these variables do not
@@ -335,10 +342,12 @@ class Parser {
   /// `inferredType` is the type of the variable inferred by the constraints
   /// within the list, and is updated to the most refined type as determined by
   /// the constraints. Returns success if the constraint list is valid, failure
-  /// otherwise.
+  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
+  /// defined constraints) may be used with the variable.
   LogicalResult
   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
-                              ast::Type &inferredType);
+                              ast::Type &inferredType,
+                              bool allowNonCoreConstraints = true);
   /// Validate a single reference to a constraint. `inferredType` contains the
   /// currently inferred variabled type and is refined within the type defined
   /// by the constraint. Returns success if the constraint is valid, failure
@@ -399,6 +408,23 @@ class Parser {
   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                     ast::CompoundStmt *rewriteBody);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  /// The set of various code completion methods. Every completion method
+  /// returns `failure` to stop the parsing process after providing completion
+  /// results.
+
+  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
+  LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
+  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
+                                           bool allowNonCoreConstraints,
+                                           bool allowInlineTypeConstraints);
+  LogicalResult codeCompleteDialectName();
+  LogicalResult codeCompleteOperationName(StringRef dialectName);
+  LogicalResult codeCompletePatternMetadata();
+
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
   //===--------------------------------------------------------------------===//
@@ -481,6 +507,9 @@ class Parser {
 
   /// A counter used when naming anonymous constraints and rewrites.
   unsigned anonymousDeclNameCounter = 0;
+
+  /// The optional code completion context.
+  CodeCompleteContext *codeCompleteContext;
 };
 } // namespace
 
@@ -890,7 +919,12 @@ FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
   return decl;
 }
 
-FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
+FailureOr<ast::NamedAttributeDecl *>
+Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
+  // Check for name code completion.
+  if (curToken.is(Token::code_complete))
+    return codeCompleteAttributeName(parentOpName);
+
   std::string attrNameStr;
   if (curToken.isString())
     attrNameStr = curToken.getStringValue();
@@ -1380,6 +1414,10 @@ Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
   Optional<SMRange> hasBoundedRecursionLoc;
 
   do {
+    // Handle metadata code completion.
+    if (curToken.is(Token::code_complete))
+      return codeCompletePatternMetadata();
+
     if (curToken.isNot(Token::identifier))
       return emitError("expected pattern metadata identifier");
     StringRef metadataStr = curToken.getSpelling();
@@ -1488,7 +1526,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
   Optional<SMRange> typeConstraint;
   auto parseSingleConstraint = [&] {
     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
-        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
+        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
+        /*allowNonCoreConstraints=*/true);
     if (failed(constraint))
       return failure();
     constraints.push_back(*constraint);
@@ -1509,7 +1548,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
 FailureOr<ast::ConstraintRef>
 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                         ArrayRef<ast::ConstraintRef> existingConstraints,
-                        bool allowInlineTypeConstraints) {
+                        bool allowInlineTypeConstraints,
+                        bool allowNonCoreConstraints) {
   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
     if (!allowInlineTypeConstraints) {
       return emitError(
@@ -1610,6 +1650,17 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
     return emitErrorAndNote(
         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
         "see the definition of `" + constraintName + "` here");
+  }
+    // Handle single entity constraint code completion.
+  case Token::code_complete: {
+    // Try to infer the current type for use by code completion.
+    ast::Type inferredType;
+    if (failed(validateVariableConstraints(existingConstraints, inferredType,
+                                           allowNonCoreConstraints)))
+      return failure();
+
+    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
+                                      allowInlineTypeConstraints);
   }
   default:
     break;
@@ -1618,9 +1669,13 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
 }
 
 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
+  // Constraint arguments may apply more complex constraints via the arguments.
+  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
+
   Optional<SMRange> typeConstraint;
   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
-                         /*allowInlineTypeConstraints=*/false);
+                         /*allowInlineTypeConstraints=*/false,
+                         allowNonCoreConstraints);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1770,6 +1825,10 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
   SMRange loc = curToken.getLoc();
   consumeToken(Token::dot);
 
+  // Check for code completion of the member name.
+  if (curToken.is(Token::code_complete))
+    return codeCompleteMemberAccess(parentExpr);
+
   // Parse the member name.
   Token memberNameTok = curToken;
   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
@@ -1784,6 +1843,10 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
   SMRange loc = curToken.getLoc();
 
+  // Check for code completion for the dialect name.
+  if (curToken.is(Token::code_complete))
+    return codeCompleteDialectName();
+
   // Handle the case of an no operation name.
   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
     if (allowEmptyName)
@@ -1797,6 +1860,10 @@ FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
     return failure();
 
+  // Check for code completion for the operation name.
+  if (curToken.is(Token::code_complete))
+    return codeCompleteOperationName(name);
+
   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
     return emitError("expected operation name after dialect namespace");
 
@@ -1843,6 +1910,7 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
       parseWrappedOperationName(allowEmptyName);
   if (failed(opNameDecl))
     return failure();
+  Optional<StringRef> opName = (*opNameDecl)->getName();
 
   // Functor used to create an implicit range variable, used for implicit "all"
   // operand or results variables.
@@ -1882,7 +1950,8 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
   SmallVector<ast::NamedAttributeDecl *> attributes;
   if (consumeIf(Token::l_brace)) {
     do {
-      FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
+      FailureOr<ast::NamedAttributeDecl *> decl =
+          parseNamedAttributeDecl(opName);
       if (failed(decl))
         return failure();
       attributes.emplace_back(*decl);
@@ -2362,9 +2431,11 @@ Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
 
 LogicalResult
 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
-                                    ast::Type &inferredType) {
+                                    ast::Type &inferredType,
+                                    bool allowNonCoreConstraints) {
   for (const ast::ConstraintRef &ref : constraints)
-    if (failed(validateVariableConstraint(ref, inferredType)))
+    if (failed(validateVariableConstraint(ref, inferredType,
+                                          allowNonCoreConstraints)))
       return failure();
   return success();
 }
@@ -2784,12 +2855,57 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
 }
 
+//===----------------------------------------------------------------------===//
+// Code Completion
+//===----------------------------------------------------------------------===//
+/// Each hook returns failure so that parsing stops after completion.
+LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
+  ast::Type parentType = parentExpr->getType();
+  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
+    codeCompleteContext->codeCompleteOperationMemberAccess(opType);
+  else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
+    codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
+  return failure();
+}
+/// Attribute names are only completed when the operation name is known.
+LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
+  if (opName)
+    codeCompleteContext->codeCompleteOperationAttributeName(*opName);
+  return failure();
+}
+/// Constraint completion is also given the current decl scope.
+LogicalResult
+Parser::codeCompleteConstraintName(ast::Type inferredType,
+                                   bool allowNonCoreConstraints,
+                                   bool allowInlineTypeConstraints) {
+  codeCompleteContext->codeCompleteConstraintName(
+      inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
+      curDeclScope);
+  return failure();
+}
+/// Signal dialect name completion.
+LogicalResult Parser::codeCompleteDialectName() {
+  codeCompleteContext->codeCompleteDialectName();
+  return failure();
+}
+/// Signal completion of operation names within `dialectName`.
+LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
+  codeCompleteContext->codeCompleteOperationName(dialectName);
+  return failure();
+}
+/// Signal pattern metadata completion.
+LogicalResult Parser::codeCompletePatternMetadata() {
+  codeCompleteContext->codeCompletePatternMetadata();
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
 
-FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
-                                                 llvm::SourceMgr &sourceMgr) {
-  Parser parser(ctx, sourceMgr);
+FailureOr<ast::Module *>
+mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+                        CodeCompleteContext *codeCompleteContext) {
+  Parser parser(ctx, sourceMgr, codeCompleteContext);
   return parser.parseModule();
 }

diff  --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
index 7bce042f8900c..2b9ba3b853f7f 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
@@ -585,3 +585,184 @@ llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams &params) {
      {"version", params.version},
  };
 }
+
+//===----------------------------------------------------------------------===//
+// TextEdit
+//===----------------------------------------------------------------------===//
+
+/// Parse a TextEdit from a JSON object with "range" and "newText" fields.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result,
+                         llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("range", result.range) && o.map("newText", result.newText);
+}
+
+/// Serialize a TextEdit to a JSON object with "range" and "newText" fields.
+llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) {
+  return llvm::json::Object{
+      {"range", value.range},
+      {"newText", value.newText},
+  };
+}
+
+/// Print a TextEdit as `<range> => "<escaped newText>"`.
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) {
+  os << value.range << " => \"";
+  llvm::printEscapedString(value.newText, os);
+  return os << '"';
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionItemKind
+//===----------------------------------------------------------------------===//
+
+/// Parse a CompletionItemKind from its integer encoding. Non-integer values
+/// and integers outside of [Text, TypeParameter] are rejected.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         CompletionItemKind &result, llvm::json::Path path) {
+  if (Optional<int64_t> intValue = value.getAsInteger()) {
+    if (*intValue < static_cast<int>(CompletionItemKind::Text) ||
+        *intValue > static_cast<int>(CompletionItemKind::TypeParameter))
+      return false;
+    result = static_cast<CompletionItemKind>(*intValue);
+    return true;
+  }
+  return false;
+}
+
+/// Return `kind` if the client supports it, otherwise a close fallback kind
+/// (defaulting to `Text`).
+CompletionItemKind mlir::lsp::adjustKindToCapability(
+    CompletionItemKind kind,
+    CompletionItemKindBitset &supportedCompletionItemKinds) {
+  size_t kindVal = static_cast<size_t>(kind);
+  // `kindVal` indexes the bitset, so it must be strictly below its size.
+  if (kindVal >= kCompletionItemKindMin &&
+      kindVal < supportedCompletionItemKinds.size() &&
+      supportedCompletionItemKinds[kindVal])
+    return kind;
+
+  // Provide some fallbacks for common kinds that are close enough.
+  switch (kind) {
+  case CompletionItemKind::Folder:
+    return CompletionItemKind::File;
+  case CompletionItemKind::EnumMember:
+    return CompletionItemKind::Enum;
+  case CompletionItemKind::Struct:
+    return CompletionItemKind::Class;
+  default:
+    return CompletionItemKind::Text;
+  }
+}
+
+/// Parse the client's supported completion item kinds from a JSON array.
+/// Array elements that are not valid kinds are skipped, not rejected.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         CompletionItemKindBitset &result,
+                         llvm::json::Path path) {
+  if (const llvm::json::Array *arrayValue = value.getAsArray()) {
+    for (size_t i = 0, e = arrayValue->size(); i < e; ++i) {
+      CompletionItemKind kindOut;
+      if (fromJSON((*arrayValue)[i], kindOut, path.index(i)))
+        result.set(size_t(kindOut));
+    }
+    return true;
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionItem
+//===----------------------------------------------------------------------===//
+
+/// Serialize a CompletionItem. `label` is required; every other field is only
+/// emitted when set (non-empty, not `Missing`, or true).
+llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) {
+  assert(!value.label.empty() && "completion item label is required");
+  llvm::json::Object result{{"label", value.label}};
+  if (value.kind != CompletionItemKind::Missing)
+    result["kind"] = static_cast<int>(value.kind);
+  if (!value.detail.empty())
+    result["detail"] = value.detail;
+  if (value.documentation)
+    result["documentation"] = value.documentation;
+  if (!value.sortText.empty())
+    result["sortText"] = value.sortText;
+  if (!value.filterText.empty())
+    result["filterText"] = value.filterText;
+  if (!value.insertText.empty())
+    result["insertText"] = value.insertText;
+  if (value.insertTextFormat != InsertTextFormat::Missing)
+    result["insertTextFormat"] = static_cast<int>(value.insertTextFormat);
+  if (value.textEdit)
+    result["textEdit"] = *value.textEdit;
+  if (!value.additionalTextEdits.empty()) {
+    result["additionalTextEdits"] =
+        llvm::json::Array(value.additionalTextEdits);
+  }
+  if (value.deprecated)
+    result["deprecated"] = value.deprecated;
+  return std::move(result);
+}
+
+/// Print a CompletionItem as its label followed by its JSON form.
+raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
+                                   const CompletionItem &value) {
+  return os << value.label << " - " << toJSON(value);
+}
+
+/// Order completion items by `sortText`, using `label` for items that have no
+/// sort text.
+bool mlir::lsp::operator<(const CompletionItem &lhs,
+                          const CompletionItem &rhs) {
+  return (lhs.sortText.empty() ? lhs.label : lhs.sortText) <
+         (rhs.sortText.empty() ? rhs.label : rhs.sortText);
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionList
+//===----------------------------------------------------------------------===//
+
+/// Serialize a CompletionList to a JSON object.
+llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) {
+  return llvm::json::Object{
+      {"isIncomplete", value.isIncomplete},
+      {"items", llvm::json::Array(value.items)},
+  };
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionContext
+//===----------------------------------------------------------------------===//
+
+/// Parse a CompletionContext. "triggerKind" is required; "triggerCharacter"
+/// is read with `mapOptOrNull` (presumably allowing it to be absent/null).
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         CompletionContext &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  int triggerKind;
+  if (!o || !o.map("triggerKind", triggerKind) ||
+      !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path))
+    return false;
+  result.triggerKind = static_cast<CompletionTriggerKind>(triggerKind);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// CompletionParams
+//===----------------------------------------------------------------------===//
+
+/// Parse CompletionParams: a text document position plus an optional
+/// "context" object.
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         CompletionParams &result, llvm::json::Path path) {
+  if (!fromJSON(value, static_cast<TextDocumentPositionParams &>(result), path))
+    return false;
+  // Only JSON objects can carry a "context" field; avoid a null dereference.
+  const llvm::json::Object *object = value.getAsObject();
+  if (!object)
+    return false;
+  if (const llvm::json::Value *context = object->get("context"))
+    return fromJSON(*context, result.context, path.field("context"));
+  return true;
+}

diff  --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h
index 6cb1dc4d50f6c..cc8c52abb104e 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.h
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains structs based on the LSP specification at
-// https://github.com/Microsoft/language-server-protocol/blob/main/protocol.md
+// https://microsoft.github.io/language-server-protocol/specification
 //
 // This is not meant to be a complete implementation, new interfaces are added
 // when they're needed.
@@ -666,6 +666,211 @@ struct PublishDiagnosticsParams {
 /// Add support for JSON serialization.
 llvm::json::Value toJSON(const PublishDiagnosticsParams &params);
 
+//===----------------------------------------------------------------------===//
+// TextEdit
+//===----------------------------------------------------------------------===//
+/// A textual edit applicable to a text document.
+struct TextEdit {
+  /// The range of the text document to be manipulated. To insert
+  /// text into a document create a range where start === end.
+  Range range;
+
+  /// The string to be inserted. For delete operations use an
+  /// empty string.
+  std::string newText;
+};
+/// Two edits are equal when both their new text and their range are equal.
+inline bool operator==(const TextEdit &lhs, const TextEdit &rhs) {
+  return std::tie(lhs.newText, lhs.range) == std::tie(rhs.newText, rhs.range);
+}
+/// Add support for JSON (de)serialization and stream printing.
+bool fromJSON(const llvm::json::Value &value, TextEdit &result,
+              llvm::json::Path path);
+llvm::json::Value toJSON(const TextEdit &value);
+raw_ostream &operator<<(raw_ostream &os, const TextEdit &value);
+
+//===----------------------------------------------------------------------===//
+// CompletionItemKind
+//===----------------------------------------------------------------------===//
+
+/// The kind of a completion entry.
+enum class CompletionItemKind {
+  Missing = 0, // Not an LSP kind; used as the unset value.
+  Text = 1,
+  Method = 2,
+  Function = 3,
+  Constructor = 4,
+  Field = 5,
+  Variable = 6,
+  Class = 7,
+  Interface = 8,
+  Module = 9,
+  Property = 10,
+  Unit = 11,
+  Value = 12,
+  Enum = 13,
+  Keyword = 14,
+  Snippet = 15,
+  Color = 16,
+  File = 17,
+  Reference = 18,
+  Folder = 19,
+  EnumMember = 20,
+  Constant = 21,
+  Struct = 22,
+  Event = 23,
+  Operator = 24,
+  TypeParameter = 25,
+};
+bool fromJSON(const llvm::json::Value &value, CompletionItemKind &result,
+              llvm::json::Path path);
+/// Smallest/largest LSP kinds; the bitset is sized to be indexed by kind.
+constexpr auto kCompletionItemKindMin =
+    static_cast<size_t>(CompletionItemKind::Text);
+constexpr auto kCompletionItemKindMax =
+    static_cast<size_t>(CompletionItemKind::TypeParameter);
+using CompletionItemKindBitset = std::bitset<kCompletionItemKindMax + 1>;
+bool fromJSON(const llvm::json::Value &value, CompletionItemKindBitset &result,
+              llvm::json::Path path);
+/// NOTE(review): presumably maps `kind` onto a kind the client supports.
+CompletionItemKind
+adjustKindToCapability(CompletionItemKind kind,
+                       CompletionItemKindBitset &supportedCompletionItemKinds);
+
+//===----------------------------------------------------------------------===//
+// CompletionItem
+//===----------------------------------------------------------------------===//
+
+/// Defines whether the insert text in a completion item should be interpreted
+/// as plain text or a snippet.
+enum class InsertTextFormat {
+  Missing = 0, // Not an LSP format; used as the unset value.
+  /// The primary text to be inserted is treated as a plain string.
+  PlainText = 1,
+  /// The primary text to be inserted is treated as a snippet.
+  ///
+  /// A snippet can define tab stops and placeholders with `$1`, `$2`
+  /// and `${3:foo}`. `$0` defines the final tab stop, it defaults to the end
+  /// of the snippet. Placeholders with equal identifiers are linked, that is
+  /// typing in one will update others too.
+  ///
+  /// See also:
+  /// https://github.com/Microsoft/vscode/blob/master/src/vs/editor/contrib/snippet/common/snippet.md
+  Snippet = 2,
+};
+/// A completion item to be presented in the editor (LSP `CompletionItem`).
+struct CompletionItem {
+  /// The label of this completion item. By default also the text that is
+  /// inserted when selecting this completion.
+  std::string label;
+
+  /// The kind of this completion item. Based of the kind an icon is chosen by
+  /// the editor.
+  CompletionItemKind kind = CompletionItemKind::Missing;
+
+  /// A human-readable string with additional information about this item, like
+  /// type or symbol information.
+  std::string detail;
+
+  /// A human-readable string that represents a doc-comment.
+  Optional<MarkupContent> documentation;
+
+  /// A string that should be used when comparing this item with other items.
+  /// When `falsy` the label is used.
+  std::string sortText;
+
+  /// A string that should be used when filtering a set of completion items.
+  /// When `falsy` the label is used.
+  std::string filterText;
+
+  /// A string that should be inserted to a document when selecting this
+  /// completion. When `falsy` the label is used.
+  std::string insertText;
+
+  /// The format of the insert text. The format applies to both the `insertText`
+  /// property and the `newText` property of a provided `textEdit`.
+  InsertTextFormat insertTextFormat = InsertTextFormat::Missing;
+
+  /// An edit which is applied to a document when selecting this completion.
+  /// When an edit is provided `insertText` is ignored.
+  ///
+  /// Note: The range of the edit must be a single line range and it must
+  /// contain the position at which completion has been requested.
+  Optional<TextEdit> textEdit;
+
+  /// An optional array of additional text edits that are applied when selecting
+  /// this completion. Edits must not overlap with the main edit nor with
+  /// themselves.
+  std::vector<TextEdit> additionalTextEdits;
+
+  /// Indicates if this item is deprecated.
+  bool deprecated = false;
+};
+/// Add support for JSON serialization and printing. Items are ordered by
+/// `sortText`, falling back to `label` when `sortText` is empty.
+llvm::json::Value toJSON(const CompletionItem &value);
+raw_ostream &operator<<(raw_ostream &os, const CompletionItem &value);
+bool operator<(const CompletionItem &lhs, const CompletionItem &rhs);
+
+//===----------------------------------------------------------------------===//
+// CompletionList
+//===----------------------------------------------------------------------===//
+
+/// Represents a collection of completion items to be presented in the editor.
+struct CompletionList {
+  /// The list is not complete. Further typing should result in recomputing the
+  /// list.
+  bool isIncomplete = false;
+
+  /// The completion items.
+  std::vector<CompletionItem> items;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const CompletionList &value);
+
+//===----------------------------------------------------------------------===//
+// CompletionContext
+//===----------------------------------------------------------------------===//
+/// How a completion request was triggered.
+enum class CompletionTriggerKind {
+  /// Completion was triggered by typing an identifier (24x7 code
+  /// complete), manual invocation (e.g Ctrl+Space) or via API.
+  Invoked = 1,
+
+  /// Completion was triggered by a trigger character specified by
+  /// the `triggerCharacters` properties of the `CompletionRegistrationOptions`.
+  TriggerCharacter = 2,
+  /// Completion was re-triggered as the current completion list is incomplete
+  /// (spelled `TriggerForIncompleteCompletions` in the LSP specification).
+  TriggerTriggerForIncompleteCompletions = 3
+};
+/// Additional information about how a completion request was triggered.
+struct CompletionContext {
+  /// How the completion was triggered.
+  CompletionTriggerKind triggerKind = CompletionTriggerKind::Invoked;
+
+  /// The trigger character (a single character) that triggered completion.
+  /// Is undefined if `triggerKind !== CompletionTriggerKind.TriggerCharacter`
+  std::string triggerCharacter;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, CompletionContext &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// CompletionParams
+//===----------------------------------------------------------------------===//
+/// Parameters of a `textDocument/completion` request.
+struct CompletionParams : TextDocumentPositionParams {
+  CompletionContext context;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, CompletionParams &result,
+              llvm::json::Path path);
+
 } // namespace lsp
 } // namespace mlir
 

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
index 214e346ed0128..57280e0bdd171 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -64,6 +64,12 @@ struct LSPServer {
   void onDocumentSymbol(const DocumentSymbolParams &params,
                         Callback<std::vector<DocumentSymbol>> reply);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+
+  void onCompletion(const CompletionParams &params,
+                    Callback<CompletionList> reply);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -94,6 +100,15 @@ void LSPServer::onInitialize(const InitializeParams &params,
            {"change", (int)TextDocumentSyncKind::Full},
            {"save", true},
        }},
+      {"completionProvider",
+       llvm::json::Object{
+           {"allCommitCharacters",
+            {" ", "\t", "(", ")", "[", "]", "{",  "}", "<",
+             ">", ":",  ";", ",", "+", "-", "/",  "*", "%",
+             "^", "&",  "#", "?", ".", "=", "\"", "'", "|"}},
+           {"resolveProvider", false},
+           {"triggerCharacters", {".", ">", "(", "{", ",", "<", ":", "[", " "}},
+       }},
       {"definitionProvider", true},
       {"referencesProvider", true},
       {"hoverProvider", true},
@@ -186,6 +201,14 @@ void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
   reply(std::move(symbols));
 }
 
+//===----------------------------------------------------------------------===//
+// Code Completion
+
+/// Reply to a `textDocument/completion` request with the server's completion
+/// list for the requested file position.
+void LSPServer::onCompletion(const CompletionParams &params,
+                             Callback<CompletionList> reply) {
+  CompletionList completions =
+      server.getCodeCompletion(params.textDocument.uri, params.position);
+  reply(std::move(completions));
+}
+
 //===----------------------------------------------------------------------===//
 // Entry Point
 //===----------------------------------------------------------------------===//
@@ -222,6 +245,10 @@ LogicalResult mlir::lsp::runPdllLSPServer(PDLLServer &server,
   messageHandler.method("textDocument/documentSymbol", &lspServer,
                         &LSPServer::onDocumentSymbol);
 
+  // Code Completion
+  messageHandler.method("textDocument/completion", &lspServer,
+                        &LSPServer::onCompletion);
+
   // Diagnostics
   lspServer.publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index cc8c22f49ab23..495df9019a174 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Tools/PDLL/ODS/Context.h"
 #include "mlir/Tools/PDLL/ODS/Dialect.h"
 #include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
 #include "mlir/Tools/PDLL/Parser/Parser.h"
 #include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/StringMap.h"
@@ -277,6 +278,13 @@ struct PDLDocument {
 
   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
 
+  //===--------------------------------------------------------------------===//
+  // Code Completion
+  //===--------------------------------------------------------------------===//
+
+  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+                                        const lsp::Position &completePos);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -546,6 +554,279 @@ void PDLDocument::findDocumentSymbols(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// PDLDocument: Code Completion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A code completion context that turns the parser's completion callbacks into
+/// LSP completion items, appending them to a caller-provided completion list.
+class LSPCodeCompleteContext : public CodeCompleteContext {
+public:
+  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
+                         ods::Context &odsContext)
+      : CodeCompleteContext(completeLoc), completionList(completionList),
+        odsContext(odsContext) {}
+
+  /// Complete a member access on a tuple: one item per element index, plus an
+  /// additional item for each element that has a name.
+  void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
+    ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
+    ArrayRef<StringRef> elementNames = tupleType.getElementNames();
+    for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
+      // Push back a completion item that uses the element index.
+      lsp::CompletionItem item;
+      item.label = llvm::formatv("{0} (field #{0})", i).str();
+      item.insertText = Twine(i).str();
+      item.filterText = item.sortText = item.insertText;
+      item.kind = lsp::CompletionItemKind::Field;
+      item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]).str();
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(item);
+
+      // If the element has a name, push back a completion item with that name.
+      // It shares the index item's kind, detail, and sort text, so the item is
+      // reused and only the name-specific fields are overwritten.
+      if (!elementNames[i].empty()) {
+        item.label =
+            llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
+        item.filterText = item.label;
+        item.insertText = elementNames[i].str();
+        completionList.items.emplace_back(std::move(item));
+      }
+    }
+  }
+
+  /// Complete a member access on an operation known to the ODS context: one
+  /// item per result index, plus an additional item for each named result.
+  void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
+    Optional<StringRef> opName = opType.getName();
+    const ods::Operation *odsOp =
+        opName ? odsContext.lookupOperation(*opName) : nullptr;
+    if (!odsOp)
+      return;
+
+    ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
+    for (const auto &it : llvm::enumerate(results)) {
+      const ods::OperandOrResult &result = it.value();
+      const ods::TypeConstraint &constraint = result.getConstraint();
+
+      // Push back a completion item that uses the result index.
+      lsp::CompletionItem item;
+      item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
+      item.insertText = Twine(it.index()).str();
+      item.filterText = item.sortText = item.insertText;
+      item.kind = lsp::CompletionItemKind::Field;
+      switch (result.getVariableLengthKind()) {
+      case ods::VariableLengthKind::Single:
+        item.detail = llvm::formatv("{0}: Value", it.index()).str();
+        break;
+      case ods::VariableLengthKind::Optional:
+        item.detail = llvm::formatv("{0}: Value?", it.index()).str();
+        break;
+      case ods::VariableLengthKind::Variadic:
+        item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
+        break;
+      }
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
+                        constraint.getCppClass())
+              .str()};
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(item);
+
+      // If the result has a name, push back a completion item with the result
+      // name (reusing the index item's remaining fields).
+      if (!result.getName().empty()) {
+        item.label =
+            llvm::formatv("{1} (field #{0})", it.index(), result.getName())
+                .str();
+        item.filterText = item.label;
+        item.insertText = result.getName().str();
+        completionList.items.emplace_back(std::move(item));
+      }
+    }
+  }
+
+  /// Complete attribute names of the ODS operation `opName`, if it is known.
+  void codeCompleteOperationAttributeName(StringRef opName) final {
+    const ods::Operation *odsOp = odsContext.lookupOperation(opName);
+    if (!odsOp)
+      return;
+
+    for (const ods::Attribute &attr : odsOp->getAttributes()) {
+      const ods::AttributeConstraint &constraint = attr.getConstraint();
+
+      lsp::CompletionItem item;
+      item.label = attr.getName().str();
+      item.kind = lsp::CompletionItemKind::Field;
+      item.detail = attr.isOptional() ? "optional" : "";
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
+                        constraint.getCppClass())
+              .str()};
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Complete constraint names. Items are offered in three groups:
+  ///  * the plain core constraints, when no type has been inferred yet;
+  ///  * the inline-typed `Attr<>`/`Value<>`/`ValueRange<>` forms, when
+  ///    `allowInlineTypeConstraints` is set and they fit `currentType`;
+  ///  * single-argument user constraints visible from `scope` (or its
+  ///    parents) whose input type refines `currentType`, when
+  ///    `allowNonCoreConstraints` is set.
+  void codeCompleteConstraintName(ast::Type currentType,
+                                  bool allowNonCoreConstraints,
+                                  bool allowInlineTypeConstraints,
+                                  const ast::DeclScope *scope) final {
+    auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
+                                 StringRef snippetText = "") {
+      lsp::CompletionItem item;
+      item.label = constraint.str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.detail = (constraint + " constraint").str();
+      item.documentation = lsp::MarkupContent{
+          lsp::MarkupKind::Markdown,
+          ("A single entity core constraint of type `" + mlirType + "`").str()};
+      item.sortText = "0";
+      item.insertText = snippetText.str();
+      item.insertTextFormat = snippetText.empty()
+                                  ? lsp::InsertTextFormat::PlainText
+                                  : lsp::InsertTextFormat::Snippet;
+      completionList.items.emplace_back(std::move(item));
+    };
+
+    // Insert completions for the core constraints. Some core constraints have
+    // additional characteristics, so we may add them even if a type has been
+    // inferred.
+    if (!currentType) {
+      addCoreConstraint("Attr", "mlir::Attribute");
+      addCoreConstraint("Op", "mlir::Operation *");
+      addCoreConstraint("Value", "mlir::Value");
+      addCoreConstraint("ValueRange", "mlir::ValueRange");
+      addCoreConstraint("Type", "mlir::Type");
+      addCoreConstraint("TypeRange", "mlir::TypeRange");
+    }
+    if (allowInlineTypeConstraints) {
+      // Attr<Type>.
+      if (!currentType || currentType.isa<ast::AttributeType>())
+        addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
+      // Value<Type>.
+      if (!currentType || currentType.isa<ast::ValueType>())
+        addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
+      // ValueRange<TypeRange>.
+      if (!currentType || currentType.isa<ast::ValueRangeType>())
+        addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
+                          "ValueRange<$1>");
+    }
+
+    // User-defined constraints are the only non-core completions below, so
+    // there is nothing left to add when they aren't allowed. Checking once up
+    // front avoids walking every declaration of every enclosing scope.
+    if (!allowNonCoreConstraints)
+      return;
+
+    // If a scope was provided, check it (and its parents) for potential
+    // constraints.
+    while (scope) {
+      for (const ast::Decl *decl : scope->getDecls()) {
+        // Skip non-constraint decls and constraints that are not single-arg.
+        // We currently only complete variable constraints.
+        const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl);
+        if (!cst || cst->getInputs().size() != 1)
+          continue;
+
+        // Ensure the input type matches the given type.
+        ast::Type constraintType = cst->getInputs()[0]->getType();
+        if (currentType && !currentType.refineWith(constraintType))
+          continue;
+
+        lsp::CompletionItem item;
+        item.label = cst->getName().getName().str();
+        item.kind = lsp::CompletionItemKind::Interface;
+        item.sortText = "2_" + item.label;
+
+        // Format the constraint signature. The stream is scoped so that it is
+        // flushed into `item.detail` before the item is stored.
+        {
+          llvm::raw_string_ostream strOS(item.detail);
+          strOS << "(";
+          llvm::interleaveComma(
+              cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
+                strOS << var->getName().getName() << ": " << var->getType();
+              });
+          strOS << ") -> " << cst->getResultType();
+        }
+
+        completionList.items.emplace_back(std::move(item));
+      }
+
+      scope = scope->getParentScope();
+    }
+  }
+
+  /// Complete the names of all dialects known to the ODS context.
+  void codeCompleteDialectName() final {
+    for (const ods::Dialect &dialect : odsContext.getDialects()) {
+      lsp::CompletionItem item;
+      item.label = dialect.getName().str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Complete operation names within the dialect `dialectName`, if known.
+  void codeCompleteOperationName(StringRef dialectName) final {
+    const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
+    if (!dialect)
+      return;
+
+    for (const auto &it : dialect->getOperations()) {
+      const ods::Operation &op = *it.second;
+
+      // Drop the leading "<dialectName>." from the label; this assumes every
+      // operation name in the dialect carries that prefix.
+      lsp::CompletionItem item;
+      item.label = op.getName().drop_front(dialectName.size() + 1).str();
+      item.kind = lsp::CompletionItemKind::Field;
+      item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+      completionList.items.emplace_back(std::move(item));
+    }
+  }
+
+  /// Complete pattern metadata keywords (`benefit` and `recursion`).
+  void codeCompletePatternMetadata() final {
+    auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
+                                   StringRef snippetText = "") {
+      lsp::CompletionItem item;
+      item.label = constraint.str();
+      item.kind = lsp::CompletionItemKind::Class;
+      item.detail = "pattern metadata";
+      item.documentation =
+          lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
+      item.insertText = snippetText.str();
+      item.insertTextFormat = snippetText.empty()
+                                  ? lsp::InsertTextFormat::PlainText
+                                  : lsp::InsertTextFormat::Snippet;
+      completionList.items.emplace_back(std::move(item));
+    };
+
+    addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
+                        "benefit($1)");
+    addSimpleConstraint("recursion",
+                        "The pattern properly handles recursive application.");
+  }
+
+private:
+  /// The list that completion items are appended to.
+  lsp::CompletionList &completionList;
+  /// The ODS context used to look up dialects, operations, and attributes.
+  ods::Context &odsContext;
+};
+} // namespace
+
+/// Compute code completions at `completePos` by re-parsing this document with
+/// an LSP code completion context attached.
+/// NOTE(review): `uri` is currently unused here.
+lsp::CompletionList
+PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
+                               const lsp::Position &completePos) {
+  lsp::CompletionList completionList;
+  SMLoc requestLoc = completePos.getAsSMLoc(sourceMgr);
+  if (!requestLoc.isValid())
+    return completionList;
+
+  // Step one character forward so the location lands after the completion
+  // trigger token.
+  SMLoc completeLoc = SMLoc::getFromPointer(requestLoc.getPointer() + 1);
+
+  // Parse the module again with the completion context attached. The parser
+  // invokes the context's callbacks, which append items to `completionList`.
+  // The parse uses its own scratch ODS/AST contexts.
+  ods::Context scratchODSContext;
+  LSPCodeCompleteContext completeContext(completeLoc, completionList,
+                                         scratchODSContext);
+  ast::Context scratchASTContext(scratchODSContext);
+  (void)parsePDLAST(scratchASTContext, sourceMgr, &completeContext);
+  return completionList;
+}
+
 //===----------------------------------------------------------------------===//
 // PDLTextFileChunk
 //===----------------------------------------------------------------------===//
@@ -600,6 +881,8 @@ class PDLTextFile {
   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                  lsp::Position hoverPos);
   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
+                                        lsp::Position completePos);
 
 private:
   /// Find the PDL document that contains the given position, and update the
@@ -737,6 +1020,22 @@ void PDLTextFile::findDocumentSymbols(
   }
 }
 
+/// Compute code completions for the chunk containing `completePos`, then
+/// translate the resulting edit ranges by that chunk's offset.
+lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
+                                                   lsp::Position completePos) {
+  // `getChunkFor` takes the position by reference and may update it for the
+  // chunk it selects.
+  PDLTextFileChunk &chunk = getChunkFor(completePos);
+  lsp::CompletionList result =
+      chunk.document.getCodeCompletion(uri, completePos);
+
+  // Shift every edit range carried by the completion items.
+  auto shiftRange = [&chunk](lsp::Range &range) {
+    chunk.adjustLocForChunkOffset(range);
+  };
+  for (lsp::CompletionItem &item : result.items) {
+    for (lsp::TextEdit &edit : item.additionalTextEdits)
+      shiftRange(edit.range);
+    if (item.textEdit)
+      shiftRange(item.textEdit->range);
+  }
+  return result;
+}
+
 PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -815,3 +1114,12 @@ void lsp::PDLLServer::findDocumentSymbols(
   if (fileIt != impl->files.end())
     fileIt->second->findDocumentSymbols(symbols);
 }
+
+/// Get the code completion list for `completePos` within the file `uri`.
+/// Files the server is not tracking produce an empty list.
+lsp::CompletionList
+lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos) {
+  auto it = impl->files.find(uri.file());
+  if (it == impl->files.end())
+    return CompletionList();
+  return it->second->getCodeCompletion(uri, completePos);
+}

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
index 1a647f18db125..5f593204bab24 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
@@ -15,6 +15,7 @@
 namespace mlir {
 namespace lsp {
 struct Diagnostic;
+struct CompletionList;
 struct DocumentSymbol;
 struct Hover;
 struct Location;
@@ -57,6 +58,10 @@ class PDLLServer {
   void findDocumentSymbols(const URIForFile &uri,
                            std::vector<DocumentSymbol> &symbols);
 
+  /// Get the code completion list for the position within the given file.
+  CompletionList getCodeCompletion(const URIForFile &uri,
+                                   const Position &completePos);
+
 private:
   struct Impl;
 

diff  --git a/mlir/test/mlir-pdll-lsp-server/completion.test b/mlir/test/mlir-pdll-lsp-server/completion.test
new file mode 100644
index 0000000000000..a8d80d8f5ff59
--- /dev/null
+++ b/mlir/test/mlir-pdll-lsp-server/completion.test
@@ -0,0 +1,205 @@
+// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.pdll",
+  "languageId":"pdll",
+  "version":1,
+  "text":"Constraint ValueCst(value: Value);\nConstraint Cst();\nPattern FooPattern with benefit(1) {\nlet tuple = (value1 = _: Op, _: Op<test.op>);\nerase tuple.value1;\n}"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.pdll"},
+  "position":{"line":4,"character":12}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "0: Op",
+// CHECK-NEXT:        "filterText": "0",
+// CHECK-NEXT:        "insertText": "0",
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 5,
+// CHECK-NEXT:        "label": "0 (field #0)",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "0: Op",
+// CHECK-NEXT:        "filterText": "value1 (field #0)",
+// CHECK-NEXT:        "insertText": "value1",
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 5,
+// CHECK-NEXT:        "label": "value1 (field #0)",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "1: Op<test.op>",
+// CHECK-NEXT:        "filterText": "1",
+// CHECK-NEXT:        "insertText": "1",
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 5,
+// CHECK-NEXT:        "label": "1 (field #1)",
+// CHECK-NEXT:        "sortText": "1"
+// CHECK-NEXT:      }
+// CHECK-NEXT:    ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.pdll"},
+  "position":{"line":2,"character":23}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "pattern metadata",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "The `benefit` of matching the pattern."
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertText": "benefit($1)",
+// CHECK-NEXT:        "insertTextFormat": 2,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "benefit"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "pattern metadata",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "The pattern properly handles recursive application."
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "recursion"
+// CHECK-NEXT:      }
+// CHECK-NEXT:    ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
+  "textDocument":{"uri":"test:///foo.pdll"},
+  "position":{"line":3,"character":24}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "isIncomplete": false,
+// CHECK-NEXT:    "items": [
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Attr constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Attribute`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Attr",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Op constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Operation *`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Op",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Value constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Value`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Value",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "ValueRange constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::ValueRange`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "ValueRange",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Type constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Type`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Type",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "TypeRange constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::TypeRange`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertTextFormat": 1,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "TypeRange",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Attr<type> constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Attribute`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertText": "Attr<$1>",
+// CHECK-NEXT:        "insertTextFormat": 2,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Attr<type>",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "Value<type> constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::Value`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertText": "Value<$1>",
+// CHECK-NEXT:        "insertTextFormat": 2,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "Value<type>",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "ValueRange<type> constraint",
+// CHECK-NEXT:        "documentation": {
+// CHECK-NEXT:          "kind": "markdown",
+// CHECK-NEXT:          "value": "A single entity core constraint of type `mlir::ValueRange`"
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "insertText": "ValueRange<$1>",
+// CHECK-NEXT:        "insertTextFormat": 2,
+// CHECK-NEXT:        "kind": 7,
+// CHECK-NEXT:        "label": "ValueRange<type>",
+// CHECK-NEXT:        "sortText": "0"
+// CHECK-NEXT:      },
+// CHECK-NEXT:      {
+// CHECK-NEXT:        "detail": "(value: Value) -> Tuple<>",
+// CHECK-NEXT:        "kind": 8,
+// CHECK-NEXT:        "label": "ValueCst",
+// CHECK-NEXT:        "sortText": "2_ValueCst"
+// CHECK-NEXT:      }
+// CHECK-NEXT:    ]
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}

diff  --git a/mlir/test/mlir-pdll-lsp-server/initialize-params.test b/mlir/test/mlir-pdll-lsp-server/initialize-params.test
index d2af6f514fe54..2d40edfaa7bc5 100644
--- a/mlir/test/mlir-pdll-lsp-server/initialize-params.test
+++ b/mlir/test/mlir-pdll-lsp-server/initialize-params.test
@@ -5,6 +5,13 @@
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "capabilities": {
+// CHECK-NEXT:      "completionProvider": {
+// CHECK-NEXT:        "allCommitCharacters": [
+// CHECK:             ],
+// CHECK-NEXT:        "resolveProvider": false,
+// CHECK-NEXT:        "triggerCharacters": [ 
+// CHECK:             ]
+// CHECK-NEXT:      },
 // CHECK-NEXT:      "definitionProvider": true,
 // CHECK-NEXT:      "documentSymbolProvider": true,
 // CHECK-NEXT:      "hoverProvider": true,


        


More information about the Mlir-commits mailing list