[llvm-branch-commits] [mlir] 6b9fa8a - [mlir][linalg] Add docstring support for named op spec

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 14 07:03:10 PST 2021


Author: Lei Zhang
Date: 2021-01-14T09:57:56-05:00
New Revision: 6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218

URL: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218
DIFF: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218.diff

LOG: [mlir][linalg] Add docstring support for named op spec

Depends on D94335

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D94548

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 1f8ef3c4021b..a5caabd212b4 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -608,10 +608,18 @@ semantics:
     perform multiple updates.
 2.  Each tensor may only be used with a single indexing expression.
 
+A `"""`-wrapped doc string can be attached to the named op. It should start
+with a one-line summary, followed by a longer description.
+
 The following specification may be used to define a named `batchmatmul` op:
 
 ```
-def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
+"""Batch matrix-multiply operation.
+
+This operation performs batch matrix-multiply over ...
+"""
+{
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
 }
 ```

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 1ce2d2ac9418..226a09669b1c 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -125,3 +125,22 @@ def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F))
   O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
     I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c)));
 }
+
+// ODS-LABEL: def Test6Op
+// ODS:       let summary = [{ My magic op. }];
+// ODS-NEXT:  let description = [{
+// ODS-NEXT:    It has two inputs.
+// ODS-NEXT:    It has one output.
+// ODS-NEXT:  }];
+//
+ods_def<Test6Op>:
+def test6(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+"""
+My magic op.
+
+It has two inputs.
+It has one output.
+"""
+{
+  C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index cb7bfd2c9c4d..f4b7f9f9323a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
 
@@ -85,6 +86,7 @@ class Token {
     // Tokens with no info.
     colon,
     comma,
+    doc_str,
     equal,
     gt,
     l_brace,
@@ -183,6 +185,9 @@ class Lexer {
   // Lex an integer.
   Token lexInteger(const char *tokStart);
 
+  // Lex a doc string delimited by triple double-quotes.
+  Token lexString(const char *tokStart);
+
   // Skip a comment line, starting with a '//'.
   void skipComment();
 
@@ -287,6 +292,8 @@ Token Lexer::lexToken() {
       return formToken(Token::Kind::star, tokStart);
     case '?':
       return formToken(Token::Kind::question, tokStart);
+    case '"':
+      return lexString(tokStart);
     case '/':
       if (*curPtr == '/') {
         skipComment();
@@ -333,6 +340,36 @@ Token Lexer::lexInteger(const char *tokStart) {
   return Token(Token::Kind::integer, str);
 }
 
+Token Lexer::lexString(const char *tokStart) {
+  assert(curPtr[-1] == '"');
+
+  if (*curPtr == '"' && *(curPtr + 1) == '"') {
+    curPtr += 2;
+    while (true) {
+      switch (*curPtr++) {
+      case '"':
+        if (*curPtr == '"' && *(curPtr + 1) == '"') {
+          Token token(Token::Kind::doc_str,
+                      StringRef(tokStart + 3, curPtr - tokStart - 4));
+          curPtr += 2;
+          return token;
+        }
+        continue;
+      case 0:
+        // If this is a random nul character in the middle of the doc string,
+        // just include it.  If it is the end of file, then it is an error.
+        if (curPtr - 1 != curBuffer.end())
+          continue;
+        return emitError(curPtr - 1, "expected '\"\"\"' to end doc string");
+      default:
+        continue;
+      }
+    }
+  }
+
+  return emitError(curPtr - 1, "expected '\"\"\"' to start doc string");
+}
+
 /// Skip a comment line, starting with a '//'.
 void Lexer::skipComment() {
   // Advance over the second '/' in a '//' comment.
@@ -1134,6 +1171,8 @@ class TCParser {
   /// Attributes are per TC def.
   std::map<std::string, RegisteredAttr> registeredAttrs;
 
+  StringRef docString;
+
   Parser &parser;
 };
 } // namespace
@@ -1655,6 +1694,14 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
       return failure();
   }
 
+  // Parse optional doc string
+  if (parser.curToken.is(Token::Kind::doc_str)) {
+    docString = parser.curToken.getSpelling();
+    parser.consumeToken();
+    LLVM_DEBUG(llvm::dbgs()
+               << "parsed doc string: '''" << docString << "'''\n");
+  }
+
   // Since we don't declare symbols separately, we discover them eagerly: each
   // newly encountered id in a tensor shape expression is treated as a new
   // symbolic. At this point, all tensors have been parsed and all the symbols
@@ -1755,9 +1802,10 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     SingleBlockImplicitTerminator<"YieldOp">]> {
+      {2}
       let arguments = (ins
         Variadic<AnyShaped>:$inputs,
-        Variadic<AnyShaped>:$outputs{4}
+        Variadic<AnyShaped>:$outputs{3}
       );
       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
       let regions = (region AnyRegion:$region);
@@ -1818,23 +1866,30 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
 
         // Generic methods.
-        static unsigned getNumRegionArgs() {{ return {5}; }
+        static unsigned getNumRegionArgs() {{ return {4}; }
         std::string getLibraryCallName() {{
           return generateLibraryCallName(getOperation());
         }
       }];
   })FMT";
 
-  unsigned nInputs = 0, nOutputs = 0;
-  for (auto &t : registeredTensors) {
-    if (t.getValue().isOutput)
-      nOutputs++;
-    else
-      nInputs++;
+  std::string doc;
+
+  if (!docString.empty()) {
+    const char *docFmt = R"FMT(
+      let summary = [{ {0} }];
+      let description = [{
+        {1}
+      }];
+    )FMT";
+
+    StringRef summary, description;
+    std::tie(summary, description) = docString.trim().split('\n');
+    doc = llvm::formatv(docFmt, summary.trim(), description.trim());
   }
 
-  os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
-                      attrList, state.orderedTensorArgs.size());
+  os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList,
+                      state.orderedTensorArgs.size());
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.


        


More information about the llvm-branch-commits mailing list