[llvm-branch-commits] [mlir] 6b9fa8a - [mlir][linalg] Add docstring support for named op spec
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 14 07:03:10 PST 2021
Author: Lei Zhang
Date: 2021-01-14T09:57:56-05:00
New Revision: 6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218
URL: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218
DIFF: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218.diff
LOG: [mlir][linalg] Add docstring support for named op spec
Depends on D94335
Reviewed By: nicolasvasilache, hanchung
Differential Revision: https://reviews.llvm.org/D94548
Added:
Modified:
mlir/docs/Dialects/Linalg.md
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 1f8ef3c4021b..a5caabd212b4 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -608,10 +608,18 @@ semantics:
perform multiple updates.
2. Each tensor may only be used with a single indexing expression.
+A `"""`-wrapped doc string can be attached to the named op. It should contain a
+one-line summary first, followed by a longer description.
+
The following specification may be used to define a named `batchmatmul` op:
```
-def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
+"""Batch matrix-multiply operation.
+
+This operation performs batch matrix-multiply over ...
+"""
+{
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}
```
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 1ce2d2ac9418..226a09669b1c 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -125,3 +125,22 @@ def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F))
O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c)));
}
+
+// ODS-LABEL: def Test6Op
+// ODS: let summary = [{ My magic op. }];
+// ODS-NEXT: let description = [{
+// ODS-NEXT: It has two inputs.
+// ODS-NEXT: It has one output.
+// ODS-NEXT: }];
+//
+ods_def<Test6Op>:
+def test6(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+"""
+My magic op.
+
+It has two inputs.
+It has one output.
+"""
+{
+ C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index cb7bfd2c9c4d..f4b7f9f9323a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
@@ -85,6 +86,7 @@ class Token {
// Tokens with no info.
colon,
comma,
+ doc_str,
equal,
gt,
l_brace,
@@ -183,6 +185,9 @@ class Lexer {
// Lex an integer.
Token lexInteger(const char *tokStart);
+ // Lex a string.
+ Token lexString(const char *tokStart);
+
// Skip a comment line, starting with a '//'.
void skipComment();
@@ -287,6 +292,8 @@ Token Lexer::lexToken() {
return formToken(Token::Kind::star, tokStart);
case '?':
return formToken(Token::Kind::question, tokStart);
+ case '"':
+ return lexString(tokStart);
case '/':
if (*curPtr == '/') {
skipComment();
@@ -333,6 +340,36 @@ Token Lexer::lexInteger(const char *tokStart) {
return Token(Token::Kind::integer, str);
}
+/// Lex a `"""`-delimited doc string. `tokStart` points at the opening quote,
+/// which the caller has already consumed. On success, the token spelling is
+/// the raw text strictly between the opening and closing `"""` delimiters.
+Token Lexer::lexString(const char *tokStart) {
+  assert(curPtr[-1] == '"');
+
+  // Only triple-quoted doc strings are accepted: two more quotes must follow.
+  const bool opensTripleQuote = curPtr[0] == '"' && curPtr[1] == '"';
+  if (!opensTripleQuote)
+    return emitError(curPtr - 1, "expected '\"\"\"' to start doc string");
+
+  curPtr += 2;
+  const char *bodyStart = tokStart + 3;
+  for (;;) {
+    const char c = *curPtr++;
+    if (c == '"') {
+      // `curPtr - 1` is the candidate first quote of the closing delimiter.
+      if (curPtr[0] == '"' && curPtr[1] == '"') {
+        StringRef body(bodyStart, curPtr - 1 - bodyStart);
+        curPtr += 2;
+        return Token(Token::Kind::doc_str, body);
+      }
+    } else if (c == 0 && curPtr - 1 == curBuffer.end()) {
+      // A NUL at the buffer end means the doc string was never closed; NULs
+      // anywhere else are simply kept as part of the doc string.
+      return emitError(curPtr - 1, "expected '\"\"\"' to end doc string");
+    }
+  }
+}
+
/// Skip a comment line, starting with a '//'.
void Lexer::skipComment() {
// Advance over the second '/' in a '//' comment.
@@ -1134,6 +1171,8 @@ class TCParser {
/// Attributes are per TC def.
std::map<std::string, RegisteredAttr> registeredAttrs;
+ StringRef docString;
+
Parser &parser;
};
} // namespace
@@ -1655,6 +1694,14 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
return failure();
}
+ // Parse optional doc string
+ if (parser.curToken.is(Token::Kind::doc_str)) {
+ docString = parser.curToken.getSpelling();
+ parser.consumeToken();
+ LLVM_DEBUG(llvm::dbgs()
+ << "parsed doc string: '''" << docString << "'''\n");
+ }
+
// Since we don't declare symbols separately, we discover them eagerly: each
// newly encountered id in a tensor shape expression is treated as a new
// symbolic. At this point, all tensors have been parsed and all the symbols
@@ -1755,9 +1802,10 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
+ {2}
let arguments = (ins
Variadic<AnyShaped>:$inputs,
- Variadic<AnyShaped>:$outputs{4}
+ Variadic<AnyShaped>:$outputs{3}
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
@@ -1818,23 +1866,30 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
// Generic methods.
- static unsigned getNumRegionArgs() {{ return {5}; }
+ static unsigned getNumRegionArgs() {{ return {4}; }
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
}];
})FMT";
- unsigned nInputs = 0, nOutputs = 0;
- for (auto &t : registeredTensors) {
- if (t.getValue().isOutput)
- nOutputs++;
- else
- nInputs++;
+ std::string doc;
+
+ if (!docString.empty()) {
+ const char *docFmt = R"FMT(
+ let summary = [{ {0} }];
+ let description = [{
+ {1}
+ }];
+ )FMT";
+
+ StringRef summary, description;
+ std::tie(summary, description) = docString.trim().split('\n');
+ doc = llvm::formatv(docFmt, summary.trim(), description.trim());
}
- os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
- attrList, state.orderedTensorArgs.size());
+ os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList,
+ state.orderedTensorArgs.size());
}
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
More information about the llvm-branch-commits
mailing list