[Mlir-commits] [mlir] bf352e0 - [mlir:PDLL] Add better support for providing Constraint/Pattern/Rewrite documentation

River Riddle llvmlistbot at llvm.org
Thu Jun 2 16:31:20 PDT 2022


Author: River Riddle
Date: 2022-06-02T16:31:07-07:00
New Revision: bf352e0b2ef9f8824a5b88d44313b5a13258350d

URL: https://github.com/llvm/llvm-project/commit/bf352e0b2ef9f8824a5b88d44313b5a13258350d
DIFF: https://github.com/llvm/llvm-project/commit/bf352e0b2ef9f8824a5b88d44313b5a13258350d.diff

LOG: [mlir:PDLL] Add better support for providing Constraint/Pattern/Rewrite documentation

This commit enables providing long-form documentation more seamlessly to the LSP
by revamping decl documentation. For ODS-imported constructs, we now also import
descriptions and attach them to decls when possible. For PDLL constructs, the LSP will
now try to provide documentation by parsing the comments directly above each decl's
location within the source file. This commit also adds a new parser flag,
`enableDocumentation`, that gates the import and attachment of ODS documentation,
which is unnecessary in the normal build process (i.e. it should only be used/consumed
by tools). The parser entry point `parsePDLAST` is also renamed to `parsePDLLAST`
and takes the new flag as an additional parameter.

Differential Revision: https://reviews.llvm.org/D124881

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Attribute.h
    mlir/include/mlir/TableGen/Constraint.h
    mlir/include/mlir/TableGen/Type.h
    mlir/include/mlir/Tools/PDLL/AST/Nodes.h
    mlir/include/mlir/Tools/PDLL/Parser/Parser.h
    mlir/lib/TableGen/Attribute.cpp
    mlir/lib/TableGen/Constraint.cpp
    mlir/lib/TableGen/Type.cpp
    mlir/lib/Tools/PDLL/AST/Nodes.cpp
    mlir/lib/Tools/PDLL/ODS/Operation.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
    mlir/test/mlir-pdll-lsp-server/hover.test
    mlir/tools/mlir-pdll/mlir-pdll.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 2c9732ea88f12..da1af3ad51331 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -113,9 +113,6 @@ class Attribute : public AttrConstraint {
 
   // Returns the dialect for the attribute if defined.
   Dialect getDialect() const;
-
-  // Returns the description of the attribute.
-  StringRef getDescription() const;
 };
 
 // Wrapper class providing helper methods for accessing MLIR constant attribute

diff  --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0c74f89189e92..71bb7d20926a7 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -50,10 +50,14 @@ class Constraint {
   // mlir::Attribute.
   std::string getConditionTemplate() const;
 
-  // Returns the user-readable description of this constraint. If the
-  // description is not provided, returns the TableGen def name.
+  // Returns the user-readable summary of this constraint. If the summary is not
+  // provided, returns the TableGen def name.
   StringRef getSummary() const;
 
+  // Returns the long-form description of this constraint. If the description is
+  // not provided, returns an empty string.
+  StringRef getDescription() const;
+
   /// Returns the name of the TablGen def of this constraint. In some cases
   /// where the current def is anonymous, the name of the base def is used (e.g.
   /// `Optional<>`/`Variadic<>` type constraints).

diff  --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 1828e88461b32..7cc263faa4c21 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -65,9 +65,6 @@ class Type : public TypeConstraint {
 public:
   explicit Type(const llvm::Record *record);
 
-  // Returns the description of the type.
-  StringRef getDescription() const;
-
   // Returns the dialect for the type if defined.
   Dialect getDialect() const;
 };

diff  --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index ab1a53f90b8fc..2e947b1106adc 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -635,6 +635,13 @@ class Decl : public Node {
   /// Provide type casting support.
   static bool classof(const Node *node);
 
+  /// Set the documentation comment for this decl.
+  void setDocComment(Context &ctx, StringRef comment);
+
+  /// Return the documentation comment attached to this decl if it has been set.
+  /// Otherwise, returns None.
+  Optional<StringRef> getDocComment() const { return docComment; }
+
 protected:
   Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
       : Node(typeID, loc), name(name) {}
@@ -643,6 +650,10 @@ class Decl : public Node {
   /// The name of the decl. This is optional for some decls, such as
   /// PatternDecl.
   const Name *name;
+
+  /// The documentation comment attached to this decl. Defaults to None if
+  /// the comment is unset/unknown.
+  Optional<StringRef> docComment;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
index ce5815a478d1d..1a43a3bbfd48c 100644
--- a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
+++ b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
@@ -26,12 +26,16 @@ class Context;
 class Module;
 } // namespace ast
 
-/// Parse an AST module from the main file of the given source manager. An
-/// optional code completion context may be provided to receive code completion
-/// suggestions. If a completion is hit, this method returns a failure.
+/// Parse an AST module from the main file of the given source manager.
+/// `enableDocumentation` is an optional flag that, when set, indicates that the
+/// parser should also include documentation when building the AST when
+/// possible. `codeCompleteContext` is an optional code completion context that
+/// may be provided to receive code completion suggestions. If a completion is
+/// hit, this method returns a failure.
 FailureOr<ast::Module *>
-parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
-            CodeCompleteContext *codeCompleteContext = nullptr);
+parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+             bool enableDocumentation = false,
+             CodeCompleteContext *codeCompleteContext = nullptr);
 } // namespace pdll
 } // namespace mlir
 

diff  --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 9254e4e983c57..80a5eae12bfdb 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -132,10 +132,6 @@ Dialect Attribute::getDialect() const {
   return Dialect(nullptr);
 }
 
-StringRef Attribute::getDescription() const {
-  return def->getValueAsString("description");
-}
-
 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
   assert(def->isSubClassOf("ConstantAttr") &&
          "must be subclass of TableGen 'ConstantAttr' class");

diff  --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 8e62120ad8374..bc7784c53037c 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -57,6 +57,10 @@ StringRef Constraint::getSummary() const {
   return def->getName();
 }
 
+StringRef Constraint::getDescription() const {
+  return def->getValueAsOptionalString("description").getValueOr("");
+}
+
 StringRef Constraint::getDefName() const {
   if (Optional<StringRef> baseDefName = getBaseDefName())
     return *baseDefName;

diff  --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 601440ebc6435..115fb3ba41dc2 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -76,10 +76,6 @@ std::string TypeConstraint::getCPPClassName() const {
 
 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
 
-StringRef Type::getDescription() const {
-  return def->getValueAsString("description");
-}
-
 Dialect Type::getDialect() const {
   return Dialect(def->getValueAsDef("dialect"));
 }

diff  --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 417483444615c..3af6ddf4dba31 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -355,6 +355,14 @@ TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
       TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
 }
 
+//===----------------------------------------------------------------------===//
+// Decl
+//===----------------------------------------------------------------------===//
+
+void Decl::setDocComment(Context &ctx, StringRef comment) {
+  docComment = comment.copy(ctx.getAllocator());
+}
+
 //===----------------------------------------------------------------------===//
 // AttrConstraintDecl
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
index c991c33be7d22..7e708be1ae4d1 100644
--- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
@@ -20,10 +20,7 @@ using namespace mlir::pdll::ods;
 Operation::Operation(StringRef name, StringRef summary, StringRef desc,
                      StringRef nativeClassName, bool supportsTypeInferrence,
                      llvm::SMLoc loc)
-    : name(name.str()), summary(summary.str()),
+    : name(name.str()), summary(summary.str()), description(desc.str()),
       nativeClassName(nativeClassName.str()),
       supportsTypeInferrence(supportsTypeInferrence),
-      location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
-  llvm::raw_string_ostream descOS(description);
-  raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t"));
-}
+      location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {}

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 9cd3d8caf40bf..4b7fd85227aa0 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Tools/PDLL/Parser/Parser.h"
 #include "Lexer.h"
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Attribute.h"
@@ -43,9 +44,10 @@ namespace {
 class Parser {
 public:
   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
-         CodeCompleteContext *codeCompleteContext)
+         bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
-        curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
+        curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
+        valueTy(ast::ValueType::get(ctx)),
         valueRangeTy(ast::ValueRangeType::get(ctx)),
         typeTy(ast::TypeType::get(ctx)),
         typeRangeTy(ast::TypeRangeType::get(ctx)),
@@ -125,6 +127,27 @@ class Parser {
     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
   }
 
+  /// Process the given documentation string, or return an empty string if
+  /// documentation isn't enabled.
+  StringRef processDoc(StringRef doc) {
+    return enableDocumentation ? doc : StringRef();
+  }
+
+  /// Process the given documentation string and format it, or return an empty
+  /// string if documentation isn't enabled.
+  std::string processAndFormatDoc(const Twine &doc) {
+    if (!enableDocumentation)
+      return "";
+    std::string docStr;
+    {
+      llvm::raw_string_ostream docOS(docStr);
+      std::string tmpDocStr = doc.str();
+      raw_indented_ostream(docOS).printReindented(
+          StringRef(tmpDocStr).rtrim(" \t"));
+    }
+    return docStr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Directives
 
@@ -140,10 +163,10 @@ class Parser {
   /// Create a user defined native constraint for a constraint imported from
   /// ODS.
   template <typename ConstraintT>
-  ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
-                                               StringRef codeBlock, SMRange loc,
-                                               ast::Type type,
-                                               StringRef nativeType);
+  ast::Decl *
+  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
+                                    SMRange loc, ast::Type type,
+                                    StringRef nativeType, StringRef docString);
   template <typename ConstraintT>
   ast::Decl *
   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
@@ -520,6 +543,10 @@ class Parser {
   /// The current token within the lexer.
   Token curToken;
 
+  /// A flag indicating if the parser should add documentation to AST nodes when
+  /// viable.
+  bool enableDocumentation;
+
   /// The most recently defined decl scope.
   ast::DeclScope *curDeclScope = nullptr;
   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
@@ -801,9 +828,10 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
   ods::Context &odsContext = ctx.getODSContext();
   auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
       -> const ods::TypeConstraint & {
-    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
-                                           cst.constraint.getSummary(),
-                                           cst.constraint.getCPPClassName());
+    return odsContext.insertTypeConstraint(
+        cst.constraint.getUniqueDefName(),
+        processDoc(cst.constraint.getSummary()),
+        cst.constraint.getCPPClassName());
   };
   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
@@ -821,20 +849,20 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
     bool inserted = false;
     ods::Operation *odsOp = nullptr;
     std::tie(odsOp, inserted) = odsContext.insertOperation(
-        op.getOperationName(), op.getSummary(), op.getDescription(),
-        op.getQualCppClassName(), supportsResultTypeInferrence,
-        op.getLoc().front());
+        op.getOperationName(), processDoc(op.getSummary()),
+        processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
+        supportsResultTypeInferrence, op.getLoc().front());
 
     // Ignore operations that have already been added.
     if (!inserted)
       continue;
 
     for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
-      odsOp->appendAttribute(
-          attr.name, attr.attr.isOptional(),
-          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
-                                               attr.attr.getSummary(),
-                                               attr.attr.getStorageType()));
+      odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
+                             odsContext.insertAttributeConstraint(
+                                 attr.attr.getUniqueDefName(),
+                                 processDoc(attr.attr.getSummary()),
+                                 attr.attr.getStorageType()));
     }
     for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
       odsOp->appendOperand(operand.name, getLengthKind(operand),
@@ -883,26 +911,27 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                       cppClassName)
             .str();
 
+    std::string desc =
+        processAndFormatDoc(def->getValueAsString("description"));
     if (def->isSubClassOf("OpInterface")) {
       decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
-          name, codeBlock, loc, opTy, cppClassName));
+          name, codeBlock, loc, opTy, cppClassName, desc));
     } else if (def->isSubClassOf("AttrInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
-              name, codeBlock, loc, attrTy, cppClassName));
+              name, codeBlock, loc, attrTy, cppClassName, desc));
     } else if (def->isSubClassOf("TypeInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
-              name, codeBlock, loc, typeTy, cppClassName));
+              name, codeBlock, loc, typeTy, cppClassName, desc));
     }
   }
 }
 
 template <typename ConstraintT>
-ast::Decl *
-Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
-                                          SMRange loc, ast::Type type,
-                                          StringRef nativeType) {
+ast::Decl *Parser::createODSNativePDLLConstraintDecl(
+    StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
+    StringRef nativeType, StringRef docString) {
   // Build the single input parameter.
   ast::DeclScope *argScope = pushDeclScope();
   auto *paramVar = ast::VariableDecl::create(
@@ -915,6 +944,7 @@ Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
   auto *constraintDecl = ast::UserConstraintDecl::createNative(
       ctx, ast::Name::create(ctx, name, loc), paramVar,
       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
+  constraintDecl->setDocComment(ctx, docString);
   curDeclScope->add(constraintDecl);
   return constraintDecl;
 }
@@ -931,8 +961,20 @@ Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
       &fmtContext);
 
+  // If documentation was enabled, build the doc string for the generated
+  // constraint. It would be nice to do this lazily, but TableGen information is
+  // destroyed after we finish parsing the file.
+  std::string docString;
+  if (enableDocumentation) {
+    StringRef desc = constraint.getDescription();
+    docString = processAndFormatDoc(
+        constraint.getSummary() +
+        (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
+  }
+
   return createODSNativePDLLConstraintDecl<ConstraintT>(
-      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType);
+      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
+      docString);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3080,8 +3122,9 @@ void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
 //===----------------------------------------------------------------------===//
 
 FailureOr<ast::Module *>
-mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
-                        CodeCompleteContext *codeCompleteContext) {
-  Parser parser(ctx, sourceMgr, codeCompleteContext);
+mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
+                         bool enableDocumentation,
+                         CodeCompleteContext *codeCompleteContext) {
+  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
   return parser.parseModule();
 }

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index cbce9f8ab0cd5..c4cb543593729 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -110,6 +110,57 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
   return lspDiag;
 }
 
+/// Get or extract the documentation for the given decl.
+static Optional<std::string> getDocumentationFor(llvm::SourceMgr &sourceMgr,
+                                                 const ast::Decl *decl) {
+  // If the decl already had documentation set, use it.
+  if (Optional<StringRef> doc = decl->getDocComment())
+    return doc->str();
+
+  // If the decl doesn't yet have documentation, try to extract it from the
+  // source file. This is a heuristic, and isn't intended to cover every case,
+  // but should cover the most common. We essentially look for a comment
+  // preceding the decl, and if we find one, use that as the documentation.
+  SMLoc startLoc = decl->getLoc().Start;
+  if (!startLoc.isValid())
+    return llvm::None;
+  int bufferId = sourceMgr.FindBufferContainingLoc(startLoc);
+  if (bufferId == 0)
+    return llvm::None;
+  const char *bufferStart =
+      sourceMgr.getMemoryBuffer(bufferId)->getBufferStart();
+  StringRef buffer(bufferStart, startLoc.getPointer() - bufferStart);
+
+  // Pop the last line from the buffer string.
+  auto popLastLine = [&]() -> Optional<StringRef> {
+    size_t newlineOffset = buffer.find_last_of("\n");
+    if (newlineOffset == StringRef::npos)
+      return llvm::None;
+    StringRef lastLine = buffer.drop_front(newlineOffset).trim();
+    buffer = buffer.take_front(newlineOffset);
+    return lastLine;
+  };
+
+  // Try to pop the current line, which contains the decl.
+  if (!popLastLine())
+    return llvm::None;
+
+  // Try to parse a comment string from the source file.
+  SmallVector<StringRef> commentLines;
+  while (Optional<StringRef> line = popLastLine()) {
+    // Check for a comment at the beginning of the line.
+    if (!line->startswith("//"))
+      break;
+
+    // Extract the document string from the comment.
+    commentLines.push_back(line->drop_while([](char c) { return c == '/'; }));
+  }
+
+  if (commentLines.empty())
+    return llvm::None;
+  return llvm::join(llvm::reverse(commentLines), "\n");
+}
+
 //===----------------------------------------------------------------------===//
 // PDLIndex
 //===----------------------------------------------------------------------===//
@@ -278,7 +329,7 @@ struct PDLDocument {
                                  const SMRange &hoverRange);
   lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
                                    const SMRange &hoverRange);
-  lsp::Hover buildHoverForPattern(const ast::PatternDecl *patternDecl,
+  lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
                                   const SMRange &hoverRange);
   lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
                                          const SMRange &hoverRange);
@@ -361,7 +412,7 @@ PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
     if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
       diagnostics.push_back(std::move(*lspDiag));
   });
-  astModule = parsePDLAST(astContext, sourceMgr);
+  astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
 
   // Initialize the set of parsed includes.
   lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
@@ -486,23 +537,25 @@ lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
   return hover;
 }
 
-lsp::Hover
-PDLDocument::buildHoverForPattern(const ast::PatternDecl *patternDecl,
-                                  const SMRange &hoverRange) {
+lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
+                                             const SMRange &hoverRange) {
   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
   {
     llvm::raw_string_ostream hoverOS(hover.contents.value);
     hoverOS << "**Pattern**";
-    if (const ast::Name *name = patternDecl->getName())
+    if (const ast::Name *name = decl->getName())
       hoverOS << ": `" << name->getName() << "`";
     hoverOS << "\n***\n";
-    if (Optional<uint16_t> benefit = patternDecl->getBenefit())
+    if (Optional<uint16_t> benefit = decl->getBenefit())
       hoverOS << "Benefit: " << *benefit << "\n";
-    if (patternDecl->hasBoundedRewriteRecursion())
+    if (decl->hasBoundedRewriteRecursion())
       hoverOS << "HasBoundedRewriteRecursion\n";
     hoverOS << "RootOp: `"
-            << patternDecl->getRootRewriteStmt()->getRootOpExpr()->getType()
-            << "`\n";
+            << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
+
+    // Format the documentation for the decl.
+    if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
+      hoverOS << "\n" << *doc << "\n";
   }
   return hover;
 }
@@ -552,20 +605,24 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
     }
     ast::Type resultType = decl->getResultType();
     if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) {
-      if (resultTupleTy.empty())
-        return hover;
-
-      hoverOS << "Results:\n";
-      for (auto it : llvm::zip(resultTupleTy.getElementNames(),
-                               resultTupleTy.getElementTypes())) {
-        StringRef name = std::get<0>(it);
-        hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
-                << std::get<1>(it) << "`\n";
+      if (!resultTupleTy.empty()) {
+        hoverOS << "Results:\n";
+        for (auto it : llvm::zip(resultTupleTy.getElementNames(),
+                                 resultTupleTy.getElementTypes())) {
+          StringRef name = std::get<0>(it);
+          hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
+                  << std::get<1>(it) << "`\n";
+        }
+        hoverOS << "***\n";
       }
     } else {
       hoverOS << "Results:\n* `" << resultType << "`\n";
+      hoverOS << "***\n";
     }
-    hoverOS << "***\n";
+
+    // Format the documentation for the decl.
+    if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
+      hoverOS << "\n" << *doc << "\n";
   }
   return hover;
 }
@@ -619,11 +676,13 @@ void PDLDocument::findDocumentSymbols(
 namespace {
 class LSPCodeCompleteContext : public CodeCompleteContext {
 public:
-  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
+  LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
+                         lsp::CompletionList &completionList,
                          ods::Context &odsContext,
                          ArrayRef<std::string> includeDirs)
-      : CodeCompleteContext(completeLoc), completionList(completionList),
-        odsContext(odsContext), includeDirs(includeDirs) {}
+      : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
+        completionList(completionList), odsContext(odsContext),
+        includeDirs(includeDirs) {}
 
   void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
     ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
@@ -798,6 +857,12 @@ class LSPCodeCompleteContext : public CodeCompleteContext {
             strOS << ") -> " << cst->getResultType();
           }
 
+          // Format the documentation for the constraint.
+          if (Optional<std::string> doc = getDocumentationFor(sourceMgr, cst)) {
+            item.documentation =
+                lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)};
+          }
+
           completionList.items.emplace_back(item);
         }
       }
@@ -921,6 +986,7 @@ class LSPCodeCompleteContext : public CodeCompleteContext {
   }
 
 private:
+  llvm::SourceMgr &sourceMgr;
   lsp::CompletionList &completionList;
   ods::Context &odsContext;
   ArrayRef<std::string> includeDirs;
@@ -938,11 +1004,13 @@ PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
   // code completion context provided.
   ods::Context tmpODSContext;
   lsp::CompletionList completionList;
-  LSPCodeCompleteContext lspCompleteContext(
-      posLoc, completionList, tmpODSContext, sourceMgr.getIncludeDirs());
+  LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
+                                            tmpODSContext,
+                                            sourceMgr.getIncludeDirs());
 
   ast::Context tmpContext(tmpODSContext);
-  (void)parsePDLAST(tmpContext, sourceMgr, &lspCompleteContext);
+  (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
+                     &lspCompleteContext);
 
   return completionList;
 }
@@ -954,10 +1022,11 @@ PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
 namespace {
 class LSPSignatureHelpContext : public CodeCompleteContext {
 public:
-  LSPSignatureHelpContext(SMLoc completeLoc, lsp::SignatureHelp &signatureHelp,
+  LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
+                          lsp::SignatureHelp &signatureHelp,
                           ods::Context &odsContext)
-      : CodeCompleteContext(completeLoc), signatureHelp(signatureHelp),
-        odsContext(odsContext) {}
+      : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
+        signatureHelp(signatureHelp), odsContext(odsContext) {}
 
   void codeCompleteCallSignature(const ast::CallableDecl *callable,
                                  unsigned currentNumArgs) final {
@@ -978,6 +1047,11 @@ class LSPSignatureHelpContext : public CodeCompleteContext {
       llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
       strOS << ") -> " << callable->getResultType();
     }
+
+    // Format the documentation for the callable.
+    if (Optional<std::string> doc = getDocumentationFor(sourceMgr, callable))
+      signatureInfo.documentation = std::move(*doc);
+
     signatureHelp.signatures.emplace_back(std::move(signatureInfo));
   }
 
@@ -1069,6 +1143,7 @@ class LSPSignatureHelpContext : public CodeCompleteContext {
   }
 
 private:
+  llvm::SourceMgr &sourceMgr;
   lsp::SignatureHelp &signatureHelp;
   ods::Context &odsContext;
 };
@@ -1084,10 +1159,12 @@ lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
   // code completion context provided.
   ods::Context tmpODSContext;
   lsp::SignatureHelp signatureHelp;
-  LSPSignatureHelpContext completeContext(posLoc, signatureHelp, tmpODSContext);
+  LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
+                                          tmpODSContext);
 
   ast::Context tmpContext(tmpODSContext);
-  (void)parsePDLAST(tmpContext, sourceMgr, &completeContext);
+  (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
+                     &completeContext);
 
   return signatureHelp;
 }

diff  --git a/mlir/test/mlir-pdll-lsp-server/hover.test b/mlir/test/mlir-pdll-lsp-server/hover.test
index 0548bb1944b49..22673fed2bd9f 100644
--- a/mlir/test/mlir-pdll-lsp-server/hover.test
+++ b/mlir/test/mlir-pdll-lsp-server/hover.test
@@ -5,13 +5,13 @@
   "uri":"test:///foo.pdll",
   "languageId":"pdll",
   "version":1,
-  "text":"Constraint FooCst();\nRewrite FooRewrite(op: Op) -> Op;\nPattern Foo {\nlet root: Op;\nerase root;\n}\n#include \"include/included.td\"\n#include \"include/included.pdll\""
+  "text":"Constraint FooCst();\n// This is documentation for the rewriter.\n/// And even more docs.\nRewrite FooRewrite(op: Op) -> Op;\nPattern Foo {\nlet root: Op;\nerase root;\n}\n#include \"include/included.td\"\n#include \"include/included.pdll\""
 }}}
 // -----
 // Hover on a variable.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":3,"character":6}
+  "position":{"line":5,"character":6}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
@@ -23,11 +23,11 @@
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 8,
-// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:        "line": 5
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 4,
-// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:        "line": 5
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -35,7 +35,7 @@
 // Hover on a pattern.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":2,"character":9}
+  "position":{"line":4,"character":9}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
@@ -47,11 +47,11 @@
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 11,
-// CHECK-NEXT:        "line": 2
+// CHECK-NEXT:        "line": 4
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 8,
-// CHECK-NEXT:        "line": 2
+// CHECK-NEXT:        "line": 4
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -59,7 +59,7 @@
 // Hover on a core constraint.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":3,"character":11}
+  "position":{"line":5,"character":11}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
@@ -71,11 +71,11 @@
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 12,
-// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:        "line": 5
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 10,
-// CHECK-NEXT:        "line": 3
+// CHECK-NEXT:        "line": 5
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -107,23 +107,23 @@
 // Hover on a user rewrite.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":1,"character":11}
+  "position":{"line":3,"character":11}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "contents": {
 // CHECK-NEXT:      "kind": "markdown",
-// CHECK-NEXT:      "value": "**Rewrite**: `FooRewrite`\n***\nParameters:\n* op: `Op`\n***\nResults:\n* `Op`\n***\n"
+// CHECK-NEXT:      "value": "**Rewrite**: `FooRewrite`\n***\nParameters:\n* op: `Op`\n***\nResults:\n* `Op`\n***\n\n This is documentation for the rewriter.\n And even more docs.\n"
 // CHECK-NEXT:    },
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 18,
-// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:        "line": 3
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 8,
-// CHECK-NEXT:        "line": 1
+// CHECK-NEXT:        "line": 3
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -131,7 +131,7 @@
 // Hover on an include file.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":6,"character":15}
+  "position":{"line":8,"character":15}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
@@ -143,11 +143,11 @@
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 30,
-// CHECK-NEXT:        "line": 6
+// CHECK-NEXT:        "line": 8
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 9,
-// CHECK-NEXT:        "line": 6
+// CHECK-NEXT:        "line": 8
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -155,7 +155,7 @@
 // Hover on an include file.
 {"jsonrpc":"2.0","id":1,"method":"textDocument/hover","params":{
   "textDocument":{"uri":"test:///foo.pdll"},
-  "position":{"line":7,"character":15}
+  "position":{"line":9,"character":15}
 }}
 //      CHECK:  "id": 1,
 // CHECK-NEXT:  "jsonrpc": "2.0",
@@ -167,11 +167,11 @@
 // CHECK-NEXT:    "range": {
 // CHECK-NEXT:      "end": {
 // CHECK-NEXT:        "character": 32,
-// CHECK-NEXT:        "line": 7
+// CHECK-NEXT:        "line": 9
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
 // CHECK-NEXT:        "character": 9,
-// CHECK-NEXT:        "line": 7
+// CHECK-NEXT:        "line": 9
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }

diff  --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index 648b8f41b0cd0..f7f6f69105829 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -43,9 +43,14 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
   sourceMgr.setIncludeDirs(includeDirs);
   sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc());
 
+  // If we are dumping ODS information, also enable documentation to ensure the
+  // summary and description information is imported as well.
+  bool enableDocumentation = dumpODS;
+
   ods::Context odsContext;
   ast::Context astContext(odsContext);
-  FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
+  FailureOr<ast::Module *> module =
+      parsePDLLAST(astContext, sourceMgr, enableDocumentation);
   if (failed(module))
     return failure();
 


        


More information about the Mlir-commits mailing list