[Mlir-commits] [mlir] 882ba48 - [mlir][Linalg] Create a tool to generate named Linalg ops from a Tensor Comprehensions-like specification.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Apr 10 11:02:50 PDT 2020
Author: Nicolas Vasilache
Date: 2020-04-10T13:59:25-04:00
New Revision: 882ba484743763b8560b08f483ae21d26ab336f9
URL: https://github.com/llvm/llvm-project/commit/882ba484743763b8560b08f483ae21d26ab336f9
DIFF: https://github.com/llvm/llvm-project/commit/882ba484743763b8560b08f483ae21d26ab336f9.diff
LOG: [mlir][Linalg] Create a tool to generate named Linalg ops from a Tensor Comprehensions-like specification.
Summary:
This revision adds a tool that generates the ODS and C++ implementation for "named" Linalg ops according to the [RFC discussion](https://llvm.discourse.group/t/rfc-declarative-named-ops-in-the-linalg-dialect/745).
While the mechanisms and language aspects are by no means set in stone, this revision allows connecting the pieces end-to-end from a mathematical-like specification.
Implementation details and short-term decisions made for bootstrapping, none of which are set in stone, include:
1. using a "[Tensor Comprehension](https://arxiv.org/abs/1802.04730)-inspired" syntax
2. implicit and eager discovery of dims and symbols when parsing
3. using EDSC ops to specify the computation (e.g. std_addf, std_mulf, ...)
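Item 2 (eager discovery) can be pictured roughly as follows. This is a minimal standalone sketch, not the tool's actual code: `discoverDim` and `indexingMap` are hypothetical helpers, and the visiting order (output indices, then reduction indices, then input tensors) is an assumption chosen so that parallel dims are numbered before reduction dims. Symbols such as `M` or `K` could be collected the same way, which is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: return the position of `name` in `dims`, appending it
// on first use. This is the "eager discovery" step: no prior declaration of
// the index name is required.
std::size_t discoverDim(std::vector<std::string> &dims,
                        const std::string &name) {
  for (std::size_t i = 0; i < dims.size(); ++i)
    if (dims[i] == name)
      return i;
  dims.push_back(name);
  return dims.size() - 1;
}

// Hypothetical helper: translate a tensor's index names into dim positions,
// discovering new dims along the way. The result plays the role of the
// result list of an indexing map, e.g. {0, 1, 3} for (d0, d1, d3).
std::vector<std::size_t> indexingMap(std::vector<std::string> &dims,
                                     const std::vector<std::string> &indices) {
  std::vector<std::size_t> positions;
  for (const std::string &name : indices)
    positions.push_back(discoverDim(dims, name));
  return positions;
}
```

For a comprehension such as `C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)))`, visiting `C`'s indices, then `<k>`, then `A` and `B` gives positions {0, 1, 2}, {3}, {0, 1, 3} and {3, 2} under these assumptions.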
A followup revision will connect this tool to tablegen mechanisms and allow the emission of named Linalg ops that automatically lower to various loop forms and run end to end.
For the following "Tensor Comprehension-inspired" string:
```
def batch_matmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}
```
With -gen-ods-decl=1, this emits (modulo formatting):
```
def batch_matmulOp : LinalgNamedStructured_Op<"batch_matmul", [
NInputs<2>,
NOutputs<1>,
NamedStructuredOpTraits]> {
let arguments = (ins Variadic<LinalgOperand>:$views);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let extraClassDeclaration = [{
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
void regionBuilder(ArrayRef<BlockArgument> args);
}];
let hasFolder = 1;
}
```
With -gen-impl=1, this emits (modulo formatting):
```
llvm::Optional<SmallVector<StringRef, 8>> batch_matmul::referenceIterators() {
return SmallVector<StringRef, 8>{ getParallelIteratorTypeName(),
getParallelIteratorTypeName(),
getParallelIteratorTypeName(),
getReductionIteratorTypeName() };
}
llvm::Optional<SmallVector<AffineMap, 8>> batch_matmul::referenceIndexingMaps()
{
MLIRContext *context = getContext();
AffineExpr d0, d1, d2, d3;
bindDims(context, d0, d1, d2, d3);
return SmallVector<AffineMap, 8>{
AffineMap::get(4, 0, {d0, d1, d3}),
AffineMap::get(4, 0, {d3, d2}),
AffineMap::get(4, 0, {d0, d1, d2}) };
}
void batch_matmul::regionBuilder(ArrayRef<BlockArgument> args) {
using namespace edsc;
using namespace intrinsics;
ValueHandle _0(args[0]), _1(args[1]), _2(args[2]);
ValueHandle _4 = std_mulf(_0, _1);
ValueHandle _5 = std_addf(_2, _4);
(linalg_yield(ValueRange{ _5 }));
}
```
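Read together, the iterator types and indexing maps above describe a four-deep loop nest whose body is the region built by `regionBuilder`. The following standalone C++ sketch spells out that reading. It is not output of the tool: the function name `batchMatmulReference` and the flat row-major `std::vector<float>` layout are assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference semantics for the batch_matmul spec above.
// Iteration space (d0, d1, d2, d3) = (b, m, n, k): d0..d2 are parallel,
// d3 is the reduction dim. Tensors are flat, row-major (an assumption).
void batchMatmulReference(const std::vector<float> &A, // Batch x M x K
                          const std::vector<float> &B, // K x N
                          std::vector<float> &C,       // Batch x M x N
                          std::size_t Batch, std::size_t M, std::size_t N,
                          std::size_t K) {
  for (std::size_t d0 = 0; d0 < Batch; ++d0)
    for (std::size_t d1 = 0; d1 < M; ++d1)
      for (std::size_t d2 = 0; d2 < N; ++d2)
        for (std::size_t d3 = 0; d3 < K; ++d3) {
          float a = A[(d0 * M + d1) * K + d3];  // indexing map (d0, d1, d3)
          float b = B[d3 * N + d2];             // indexing map (d3, d2)
          float &c = C[(d0 * M + d1) * N + d2]; // indexing map (d0, d1, d2)
          c = c + a * b;                        // std_addf(c, std_mulf(a, b))
        }
}
```

The caller is expected to initialize `C` (for example to zero) before the call, since the body accumulates into the existing output value, mirroring `std_addf(_2, _4)` where `_2` is the output block argument.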
Differential Revision: https://reviews.llvm.org/D77067
Added:
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Modified:
mlir/docs/Dialects/Linalg.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/IR/AffineExpr.h
mlir/lib/IR/AffineExpr.cpp
mlir/test/CMakeLists.txt
mlir/test/lit.cfg.py
mlir/tools/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 878ce8f11523..9fb5d54f9984 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -451,6 +451,93 @@ from a description in terms of only the generic op interface.
This is the main reason there are only a small number of ops today: we expect
them to be auto-generated from Tablegen soon.
+### Named Payload Ops Specification
+
+Linalg provides a declarative specification and a generation tool
+(`mlir-linalg-ods-gen`) to automatically produce named ops from a notation that
+is inspired by Einstein notation.
+
+The syntax and semantics used in `mlir-linalg-ods-gen` are very much in flight
+and borrow from Tensor Comprehensions (TC) but differ in a few dimensions, to
+better adapt to Linalg:
+
+1. The input and output tensor parameters are specified as `id :
+ type(symbolic-affine-expression-list)` (e.g. `A : f32(M, N + M)`) and each
+ new symbol is discovered eagerly. TC on the other hand does not allow
+ general symbolic affine expressions.
+1. The output shapes are specified explicitly; in TC they are always derived
+ from the input shapes.
+1. The operations used to specify computations use EDSC intrinsics so that they
+ can easily be parsed and emitted into a simple region builder without
+ resorting to more general MLIR parsing.
+1. Reduction dimensions are specified with angle bracket notation on the
+ operation they apply to (e.g. `std_addf<k>` specifies that `k` is a
+ reduction dimension). In TC, a reduction is specified with the `op=`
+ operator and the reduction dimensions are inferred.
+1. The parallel and reduction dimensions are ordered by the textual program
+ order. For instance, in the comprehension `O(i, j) = std_addf<k, l>(...)`,
+ `i` (resp. `j`) is a parallel iterator encoded by affine dimension of
+ position `0` (resp. `1`); `k` (resp. `l`) is a reduction iterator encoded by
+ an affine dimension of position `2` (resp. `3`).
+
+These decisions and syntax are subject to evolution and change. In particular,
+op-specific attributes, dynamic ranks, some form of templating, shape
+calculation function specification, etc. may be added in the future.
+
+At this time, the following restrictions are imposed on the syntax and
+semantics:
+
+1. Each def may only contain a single comprehension but each comprehension may
+ perform multiple updates.
+2. Each tensor may only be used with a single indexing expression.
+
+The following specification may be used to define a named `batchmatmul` op:
+
+```
+def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+ C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+}
+```
+
+When `mlir-linalg-ods-gen -gen-ods-decl=1` is called, the following ODS is
+produced:
+
+```
+ def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
+ NInputs<2>,
+ NOutputs<1>,
+ NamedStructuredOpTraits]> { ... }
+```
+
+When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced:
+
+```
+llvm::Optional<SmallVector<StringRef, 8>> batchmatmul::referenceIterators() {
+ return SmallVector<StringRef, 8>{
+ getParallelIteratorTypeName(),
+ getParallelIteratorTypeName(),
+ getParallelIteratorTypeName(),
+ getReductionIteratorTypeName() };
+}
+llvm::Optional<SmallVector<AffineMap, 8>> batchmatmul::referenceIndexingMaps() {
+ MLIRContext *context = getContext();
+ AffineExpr d0, d1, d2, d3;
+ bindDims(context, d0, d1, d2, d3);
+ return SmallVector<AffineMap, 8>{
+ AffineMap::get(4, 0, {d0, d1, d3}),
+ AffineMap::get(4, 0, {d3, d2}),
+ AffineMap::get(4, 0, {d0, d1, d2}) };
+}
+void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
+ using namespace edsc;
+ using namespace intrinsics;
+ ValueHandle _0(args[0]), _1(args[1]), _2(args[2]);
+ ValueHandle _4 = std_mulf(_0, _1);
+ ValueHandle _5 = std_addf(_2, _4);
+ (linalg_yield(ValueRange{ _5 }));
+}
+```
+
## Open Issues and Design Alternatives<a name="open_issues"></a>
Multiple open issues and design alternatives are in flight and it is time to
lay them out for the community to discuss and pick apart:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fd3770af3592..641039afd15d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -256,7 +256,7 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
/// OptionalAttr<I64ArrayAttr>:$strides
/// OptionalAttr<I64ArrayAttr>:$dilations
/// OptionalAttr<I64ElementsAttr>:$padding
-/// `strides` denotes the step of each window along the dimension.
+/// `strides` denotes the step of each window along the dimension.
class PoolingBase_Op<string mnemonic, list<OpTrait> props>
: LinalgStructured_Op<mnemonic, props> {
let description = [{
@@ -821,4 +821,18 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Named Linalg ops, implemented as declarative configurations of generic ops.
+//===----------------------------------------------------------------------===//
+
+def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
+
+class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
+ : Op<Linalg_Dialect, mnemonic,
+ !listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
+ string spec = ?;
+ let assemblyFormat = "`(` operands `)` attr-dict `:` "
+ "functional-type(operands, results)";
+}
+
#endif // LINALG_STRUCTURED_OPS
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 5d3e86bc9ba1..2302abd8554d 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -219,7 +219,7 @@ AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
ArrayRef<AffineExpr> localExprs,
MLIRContext *context);
-raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
+raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
template <typename U> bool AffineExpr::isa() const {
if (std::is_same<U, AffineBinaryOpExpr>::value)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 295b5155c29b..5d6ec2f1ed8e 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -613,7 +613,7 @@ AffineExpr AffineExpr::compose(AffineMap map) const {
map.getResults().end());
return replaceDimsAndSymbols(dimReplacements, {});
}
-raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr &expr) {
+raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
expr.print(os);
return os;
}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 65235fee23b5..b8d2a6e05594 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -35,6 +35,7 @@ set(MLIR_TEST_DEPENDS
MLIRUnitTests
mlir-cpu-runner
mlir-edsc-builder-api-test
+ mlir-linalg-ods-gen
mlir-opt
mlir-sdbm-api-test
mlir-tblgen
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 84a4de057c13..65f80315d57a 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -21,7 +21,7 @@
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll']
+config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
new file mode 100644
index 000000000000..680d3ee28f80
--- /dev/null
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -0,0 +1,75 @@
+// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
+// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
+
+// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \
+// RUN: | mlir-tblgen -gen-op-decls -I %S/../../include
+
+// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [
+// ODS-NEXT: NInputs<2>,
+// ODS-NEXT: NOutputs<1>,
+// ODS-NEXT: NamedStructuredOpTraits]>
+//
+// IMPL-LABEL: matvec::referenceIterators() {
+// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+//
+// IMPL: matvec::referenceIndexingMaps() {
+// IMPL: AffineMap::get(2, 0, {d0, d1}),
+// IMPL-NEXT: AffineMap::get(2, 0, {d1}),
+// IMPL-NEXT: AffineMap::get(2, 0, {d0}) };
+//
+// IMPL: matvec::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
+// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
+// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
+// IMPL: (linalg_yield(ValueRange{ [[e]] }));
+//
+def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
+ C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+}
+
+// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [
+// ODS-NEXT: NInputs<2>,
+// ODS-NEXT: NOutputs<1>,
+// ODS-NEXT: NamedStructuredOpTraits]>
+//
+// IMPL-LABEL: matmul::referenceIterators() {
+// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+//
+// IMPL: matmul::referenceIndexingMaps() {
+// IMPL: AffineMap::get(3, 0, {d0, d2}),
+// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}),
+// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) };
+//
+// IMPL: matmul::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
+// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
+// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
+// IMPL: (linalg_yield(ValueRange{ [[e]] }));
+//
+def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
+ C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+}
+
+// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
+// ODS-NEXT: NInputs<2>,
+// ODS-NEXT: NOutputs<1>,
+// ODS-NEXT: NamedStructuredOpTraits]>
+//
+// IMPL-LABEL: batchmatmul::referenceIterators() {
+// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+//
+// IMPL: batchmatmul::referenceIndexingMaps() {
+// IMPL: AffineMap::get(4, 0, {d0, d1, d3}),
+// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}),
+// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) };
+//
+// IMPL: batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
+// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
+// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
+// IMPL: (linalg_yield(ValueRange{ [[e]] }));
+//
+// TBLGEN: batchmatmulOp
+def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+ C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 5c0125c270d8..f01648bffec3 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(mlir-cuda-runner)
add_subdirectory(mlir-cpu-runner)
+add_subdirectory(mlir-linalg-ods-gen)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-translate)
add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
new file mode 100644
index 000000000000..b4fa6e35fc9a
--- /dev/null
+++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_tool(mlir-linalg-ods-gen
+ mlir-linalg-ods-gen.cpp
+)
+llvm_update_compile_flags(mlir-linalg-ods-gen)
+target_link_libraries(mlir-linalg-ods-gen PRIVATE
+ MLIRParser
+ MLIRSupport
+ LLVMCore
+ LLVMSupport
+ )
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
new file mode 100644
index 000000000000..84e10d9eb327
--- /dev/null
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -0,0 +1,1659 @@
+//===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation for the Tensor Comprehension-inspired
+// parser and ODS pretty-printer for specifying Linalg "named ops" from a
+// mathematical form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+#define DEBUG_TYPE "linalg-ods-gen"
+
+static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
+
+// Commandline options
+static llvm::cl::opt<std::string>
+ inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"), llvm::cl::value_desc("filename"));
+
+static llvm::cl::opt<std::string>
+ outputFilename("o", llvm::cl::desc("Output filename"),
+ llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+static llvm::cl::opt<bool>
+ genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
+ llvm::cl::cat(ODSGenCat));
+
+static llvm::cl::opt<bool>
+ genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
+ llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
+
+static llvm::cl::opt<bool> testEmitIncludeTdHeader(
+ "test-emit-include-td-header",
+ llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
+ "tblgen testing."),
+ llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
+
+using llvm::SetVector;
+using llvm::SMLoc;
+using llvm::StringRef;
+using llvm::Twine;
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Lexer
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents a specific token in the input format.
+class Token {
+public:
+ enum class Kind {
+ // Markers.
+ eof,
+ error,
+
+ // Tokens with no info.
+ colon,
+ comma,
+ equal,
+ gt,
+ l_brace,
+ l_paren,
+ lt,
+ minus,
+ plus,
+ r_brace,
+ r_paren,
+ semicolon,
+ star,
+
+ // Keywords.
+ kw_def,
+ FIRST_KEYWORD = kw_def,
+ kw_floordiv,
+ kw_ceildiv,
+ kw_mod,
+ LAST_KEYWORD = kw_mod,
+
+ // String valued tokens.
+ id,
+ integer,
+ };
+
+ Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+ /// Return the bytes that make up this token.
+ StringRef getSpelling() const { return spelling; }
+
+ /// Return the kind of this token.
+ Kind getKind() const { return kind; }
+
+ /// Return a location for this token.
+ llvm::SMLoc getLoc() const {
+ return llvm::SMLoc::getFromPointer(spelling.data());
+ }
+
+ /// Return if this token is a keyword.
+ bool isKeyword() const {
+ return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
+ }
+ bool is(Kind k) const { return kind == k; }
+ bool isNot(Kind k) const { return kind != k; }
+
+ Optional<uint64_t> getUInt64IntegerValue() const {
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+ uint64_t result = 0;
+ if (spelling.getAsInteger(isHex ? 0 : 10, result))
+ return None;
+ return result;
+ }
+
+private:
+ /// Discriminator that indicates the kind of token this is.
+ Kind kind;
+
+ /// A reference to the entire token contents; this is always a pointer into
+ /// a memory buffer owned by the source manager.
+ StringRef spelling;
+};
+
+/// This class implements a simple lexer.
+class Lexer {
+public:
+ Lexer(llvm::SourceMgr &mgr);
+
+ /// Lex the next token and return it.
+ Token lexToken();
+
+ /// Emit an error to the lexer with the given location and message.
+ Token emitError(llvm::SMLoc loc, const Twine &msg);
+ Token emitError(const char *loc, const Twine &msg);
+
+private:
+ Token formToken(Token::Kind kind, const char *tokStart) {
+ return Token(kind, StringRef(tokStart, curPtr - tokStart));
+ }
+
+ /// Return the next character in the stream.
+ int getNextChar();
+
+ /// Lex an identifier.
+ Token lexIdentifier(const char *tokStart);
+
+ // Lex an integer.
+ Token lexInteger(const char *tokStart);
+
+ // Skip a comment line, starting with a '//'.
+ void skipComment();
+
+ llvm::SourceMgr &srcMgr;
+ StringRef curBuffer;
+ const char *curPtr;
+};
+} // end anonymous namespace
+
+Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
+ curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
+ curPtr = curBuffer.begin();
+}
+
+Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
+ srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+ return formToken(Token::Kind::error, loc.getPointer());
+}
+Token Lexer::emitError(const char *loc, const Twine &msg) {
+ return emitError(llvm::SMLoc::getFromPointer(loc), msg);
+}
+
+int Lexer::getNextChar() {
+ char curChar = *curPtr++;
+ switch (curChar) {
+ default:
+ return (unsigned char)curChar;
+ case 0: {
+ // A nul character in the stream is either the end of the current buffer
+ // or a random nul in the file. Disambiguate that here.
+ if (curPtr - 1 != curBuffer.end())
+ return 0;
+
+ // Otherwise, return end of file.
+ --curPtr;
+ return EOF;
+ }
+ case '\n':
+ case '\r':
+ // Handle the newline character by ignoring it and incrementing the line
+ // count. However, be careful about 'dos style' files with \n\r in them.
+ // Only treat a \n\r or \r\n as a single line.
+ if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
+ ++curPtr;
+ return '\n';
+ }
+}
+
+Token Lexer::lexToken() {
+ while (true) {
+ const char *tokStart = curPtr;
+
+ // This always consumes at least one character.
+ int curChar = getNextChar();
+ switch (curChar) {
+ default:
+ // Handle identifiers: [a-zA-Z_]
+ if (isalpha(curChar) || curChar == '_')
+ return lexIdentifier(tokStart);
+
+ // Handle integers: [0-9]
+ if (isdigit(curChar))
+ return lexInteger(tokStart);
+
+ // Unknown character, emit an error.
+ return emitError(tokStart, "unexpected character");
+
+ case EOF:
+ // Return EOF denoting the end of lexing.
+ return formToken(Token::Kind::eof, tokStart);
+
+ // Lex punctuation.
+ case ':':
+ return formToken(Token::Kind::colon, tokStart);
+ case ',':
+ return formToken(Token::Kind::comma, tokStart);
+ case '=':
+ return formToken(Token::Kind::equal, tokStart);
+ case '{':
+ return formToken(Token::Kind::l_brace, tokStart);
+ case '(':
+ return formToken(Token::Kind::l_paren, tokStart);
+ case '}':
+ return formToken(Token::Kind::r_brace, tokStart);
+ case ')':
+ return formToken(Token::Kind::r_paren, tokStart);
+ case '<':
+ return formToken(Token::Kind::lt, tokStart);
+ case '>':
+ return formToken(Token::Kind::gt, tokStart);
+ case '+':
+ return formToken(Token::Kind::plus, tokStart);
+ case '-':
+ return formToken(Token::Kind::minus, tokStart);
+ case ';':
+ return formToken(Token::Kind::semicolon, tokStart);
+ case '*':
+ return formToken(Token::Kind::star, tokStart);
+ case '/':
+ if (*curPtr == '/') {
+ skipComment();
+ continue;
+ }
+ // Unknown character, emit an error.
+ return emitError(tokStart, "unexpected character: not a comment");
+
+ // Ignore whitespace characters.
+ case 0:
+ case ' ':
+ case '\t':
+ case '\n':
+ return lexToken();
+ }
+ }
+}
+
+Token Lexer::lexIdentifier(const char *tokStart) {
+ // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
+ while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
+ ++curPtr;
+
+ // Check to see if this identifier is a keyword.
+ StringRef str(tokStart, curPtr - tokStart);
+ Token::Kind kind = llvm::StringSwitch<Token::Kind>(str)
+ .Case("def", Token::Kind::kw_def)
+ .Case("floordiv", Token::Kind::kw_floordiv)
+ .Case("ceildiv", Token::Kind::kw_ceildiv)
+ .Case("mod", Token::Kind::kw_mod)
+ .Default(Token::Kind::id);
+
+ return Token(kind, str);
+}
+
+Token Lexer::lexInteger(const char *tokStart) {
+ // Match the rest of the integer regex: [0-9]*
+ while (isdigit(*curPtr))
+ ++curPtr;
+
+ StringRef str(tokStart, curPtr - tokStart);
+ return Token(Token::Kind::integer, str);
+}
+
+/// Skip a comment line, starting with a '//'.
+void Lexer::skipComment() {
+ // Advance over the second '/' in a '//' comment.
+ assert(*curPtr == '/');
+ ++curPtr;
+
+ while (true) {
+ switch (*curPtr++) {
+ case '\n':
+ case '\r':
+ // Newline is end of comment.
+ return;
+ case 0:
+ // If this is the end of the buffer, end the comment.
+ if (curPtr - 1 == curBuffer.end()) {
+ --curPtr;
+ return;
+ }
+ LLVM_FALLTHROUGH;
+ default:
+ // Skip over other characters.
+ break;
+ }
+ }
+}
+
+namespace {
+
+class Parser {
+public:
+ Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
+ : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
+
+ //===--------------------------------------------------------------------===//
+ // Lexer Utilities
+ //===--------------------------------------------------------------------===//
+
+ /// Advance the current lexer onto the next token.
+ void consumeToken() {
+ assert(curToken.getKind() != Token::Kind::eof &&
+ curToken.getKind() != Token::Kind::error &&
+ "shouldn't advance past EOF or errors");
+ curToken = lexer.lexToken();
+ }
+ void consumeToken(Token::Kind kind) {
+ assert(curToken.getKind() == kind && "unexpected token");
+ curToken = lexer.lexToken();
+ }
+ LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
+ if (curToken.getKind() != kind)
+ return emitError(curToken.getLoc(), msg);
+ consumeToken();
+ return success();
+ }
+ LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
+ lexer.emitError(loc, msg);
+ return failure();
+ }
+ LogicalResult emitError(const Twine &msg) {
+ return emitError(curToken.getLoc(), msg);
+ }
+ bool consumeIf(Token::Kind kind) {
+ if (curToken.isNot(kind))
+ return false;
+ consumeToken(kind);
+ return true;
+ }
+ LogicalResult
+ parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
+ // Non-empty case starts with an element.
+ if (parseElement())
+ return failure();
+
+ // Otherwise we have a list of comma separated elements.
+ while (consumeIf(Token::Kind::comma)) {
+ if (parseElement())
+ return failure();
+ }
+ return success();
+ }
+ LogicalResult
+ parseCommaSeparatedListUntil(Token::Kind rightToken,
+ llvm::function_ref<ParseResult()> parseElement,
+ bool allowEmptyList) {
+ // Handle the empty case.
+ if (curToken.is(rightToken)) {
+ if (!allowEmptyList)
+ return emitError("expected list element");
+ consumeToken(rightToken);
+ return success();
+ }
+
+ if (failed(parseCommaSeparatedList(parseElement)) ||
+ failed(
+ parseToken(rightToken, "expected ',' or right-terminating token")))
+ return failure();
+
+ return success();
+ }
+
+ Lexer lexer;
+ Token curToken;
+ MLIRContext *context;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Affine parsing.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Lower precedence ops (all at the same precedence level). LNoOp is false in
+/// the boolean sense.
+enum AffineLowPrecOp {
+ /// Null value.
+ LNoOp,
+ Add,
+ Sub
+};
+
+/// Higher precedence ops - all at the same precedence level. HNoOp is false
+/// in the boolean sense.
+enum AffineHighPrecOp {
+ /// Null value.
+ HNoOp,
+ Mul,
+ FloorDiv,
+ CeilDiv,
+ Mod
+};
+
+using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
+using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
+
+/// This is a specialized parser for affine expressions.
+class AffineParser {
+public:
+ explicit AffineParser(Parser &p,
+ std::function<AffineExpr(StringRef)> bareIdParsingHook,
+ AffineDimList &dimList, AffineSymbolList &symbolList)
+ : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
+ symbols(symbolList) {}
+
+ /// Parse a comma-separated list of affine exprs.
+ SmallVector<AffineExpr, 4>
+ parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
+ Token::Kind rDelim = Token::Kind::r_paren);
+
+ /// Parse a single affine expr.
+ AffineExpr parseAffineExpr();
+
+private:
+ // Binary affine op parsing.
+ AffineLowPrecOp consumeIfLowPrecOp();
+ AffineHighPrecOp consumeIfHighPrecOp();
+
+ // AffineExpr parsing.
+ AffineExpr parseParentheticalExpr();
+ AffineExpr parseNegateExpression(AffineExpr lhs);
+ AffineExpr parseIntegerExpr();
+ AffineExpr parseBareIdExpr();
+
+ AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
+ AffineExpr rhs, SMLoc opLoc);
+ AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
+ AffineExpr rhs);
+ AffineExpr parseAffineOperandExpr(AffineExpr lhs);
+ AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
+ AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc);
+
+ Parser &parser;
+ std::function<AffineExpr(StringRef)> bareIdFallback;
+ AffineDimList &dims;
+ AffineSymbolList &symbols;
+};
+} // end anonymous namespace
+
+/// Create an affine binary high precedence op expression (mul's, div's, mod).
+/// opLoc is the location of the op token to be used to report errors
+/// for non-conforming expressions.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
+ AffineExpr lhs, AffineExpr rhs,
+ SMLoc opLoc) {
+ switch (op) {
+ case Mul:
+ if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
+ parser.emitError(opLoc,
+ "non-affine expression: at least one of the multiply "
+ "operands has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs * rhs;
+ case FloorDiv:
+ if (!rhs.isSymbolicOrConstant()) {
+ parser.emitError(opLoc,
+ "non-affine expression: right operand of floordiv "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs.floorDiv(rhs);
+ case CeilDiv:
+ if (!rhs.isSymbolicOrConstant()) {
+ parser.emitError(opLoc, "non-affine expression: right operand of ceildiv "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs.ceilDiv(rhs);
+ case Mod:
+ if (!rhs.isSymbolicOrConstant()) {
+ parser.emitError(opLoc, "non-affine expression: right operand of mod "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs % rhs;
+ case HNoOp:
+ llvm_unreachable("can't create affine expression for null high prec op");
+ return nullptr;
+ }
+ llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Create an affine binary low precedence op expression (add, sub).
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
+ AffineExpr lhs, AffineExpr rhs) {
+ switch (op) {
+ case AffineLowPrecOp::Add:
+ return lhs + rhs;
+ case AffineLowPrecOp::Sub:
+ return lhs - rhs;
+ case AffineLowPrecOp::LNoOp:
+ llvm_unreachable("can't create affine expression for null low prec op");
+ return nullptr;
+ }
+ llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// Consume this token if it is a lower precedence affine op (there are only
+/// two precedence levels).
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
+ switch (parser.curToken.getKind()) {
+ case Token::Kind::plus:
+ parser.consumeToken();
+ return AffineLowPrecOp::Add;
+ case Token::Kind::minus:
+ parser.consumeToken();
+ return AffineLowPrecOp::Sub;
+ default:
+ return AffineLowPrecOp::LNoOp;
+ }
+}
+
+/// Consume this token if it is a higher precedence affine op (there are only
+/// two precedence levels)
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
+ switch (parser.curToken.getKind()) {
+ case Token::Kind::star:
+ parser.consumeToken(Token::Kind::star);
+ return Mul;
+ case Token::Kind::kw_floordiv:
+ parser.consumeToken(Token::Kind::kw_floordiv);
+ return FloorDiv;
+ case Token::Kind::kw_ceildiv:
+ parser.consumeToken(Token::Kind::kw_ceildiv);
+ return CeilDiv;
+ case Token::Kind::kw_mod:
+ parser.consumeToken(Token::Kind::kw_mod);
+ return Mod;
+ default:
+ return HNoOp;
+ }
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+/// expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token that will be used to
+/// report an error for non-conforming expressions.
+AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
+ AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc) {
+ AffineExpr lhs = parseAffineOperandExpr(llhs);
+ if (!lhs)
+ return nullptr;
+
+ // Found an LHS. Parse the remaining expression.
+ auto opLoc = parser.curToken.getLoc();
+ if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+ if (llhs) {
+ AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
+ if (!expr)
+ return nullptr;
+ return parseAffineHighPrecOpExpr(expr, op, opLoc);
+ }
+ // No LLHS, get RHS
+ return parseAffineHighPrecOpExpr(lhs, op, opLoc);
+ }
+
+ // This is the last operand in this expression.
+ if (llhs)
+ return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+
+ // No llhs, 'lhs' itself is the expression.
+ return lhs;
+}
+
+/// Parse an affine expression inside parentheses.
+///
+/// affine-expr ::= `(` affine-expr `)`
+AffineExpr AffineParser::parseParentheticalExpr() {
+ if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
+ return nullptr;
+ if (parser.curToken.is(Token::Kind::r_paren))
+ return (parser.emitError("no expression inside parentheses"), nullptr);
+
+ auto expr = parseAffineExpr();
+ if (!expr)
+ return nullptr;
+ if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
+ return nullptr;
+
+ return expr;
+}
+
+/// Parse the negation expression.
+///
+/// affine-expr ::= `-` affine-expr
+AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
+ if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
+ return nullptr;
+
+ AffineExpr operand = parseAffineOperandExpr(lhs);
+ // Since negation has the highest precedence of all ops (including high
+ // precedence ops) but lower than parentheses, we are only going to use
+ // parseAffineOperandExpr instead of parseAffineExpr here.
+ if (!operand)
+ // Extra error message although parseAffineOperandExpr would have
+ // complained. Leads to a better diagnostic.
+ return (parser.emitError("missing operand of negation"), nullptr);
+ return (-1) * operand;
+}
+
+/// Parse a bare id that may appear in an affine expression.
+///
+/// affine-expr ::= bare-id
+AffineExpr AffineParser::parseBareIdExpr() {
+ if (parser.curToken.isNot(Token::Kind::id))
+ return (parser.emitError("expected id"), nullptr);
+
+ StringRef sRef = parser.curToken.getSpelling();
+ for (auto &list : {dims, symbols}) {
+ for (auto entry : list) {
+ if (entry.first == sRef) {
+ parser.consumeToken(Token::Kind::id);
+ return entry.second;
+ }
+ }
+ }
+
+ // Not found, check fallback path.
+ AffineExpr expr = bareIdFallback(sRef);
+ if (expr) {
+ parser.consumeToken(Token::Kind::id);
+ return expr;
+ }
+
+ return (parser.emitError("use of undeclared id"), nullptr);
+}
+
+/// Parse a positive integral constant appearing in an affine expression.
+///
+/// affine-expr ::= integer-literal
+AffineExpr AffineParser::parseIntegerExpr() {
+ auto val = parser.curToken.getUInt64IntegerValue();
+ if (!val.hasValue() || (int64_t)val.getValue() < 0)
+ return (parser.emitError("constant too large for index"), nullptr);
+
+ parser.consumeToken(Token::Kind::integer);
+ return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
+}
+
+/// Parses an expression that can be a valid operand of an affine expression.
+/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
+/// operator, the rhs of which is being parsed. This is used to determine
+/// whether an error should be emitted for a missing right operand.
+// Eg: for an expression without parentheses (like i + j + k + l), each
+// of the four identifiers is an operand. For i + j*k + l, j*k is not an
+// operand expression, it's an op expression and will be parsed via
+// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
+// -l are valid operands that will be parsed by this function.
+AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
+ switch (parser.curToken.getKind()) {
+ case Token::Kind::id:
+ return parseBareIdExpr();
+ case Token::Kind::integer:
+ return parseIntegerExpr();
+ case Token::Kind::l_paren:
+ return parseParentheticalExpr();
+ case Token::Kind::minus:
+ return parseNegateExpression(lhs);
+ case Token::Kind::kw_ceildiv:
+ case Token::Kind::kw_floordiv:
+ case Token::Kind::kw_mod:
+ case Token::Kind::plus:
+ case Token::Kind::star:
+ if (lhs)
+ parser.emitError("missing right operand of binary operator");
+ else
+ parser.emitError("missing left operand of binary operator");
+ return nullptr;
+ default:
+ if (lhs)
+ parser.emitError("missing right operand of binary operator");
+ else
+ parser.emitError("expected affine expression");
+ return nullptr;
+ }
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
+///
+/// All binary op's associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, div, and mod}.
+///
+/// Add, sub'are themselves at the same precedence level. Mul, floordiv,
+/// ceildiv, and mod are at the same higher precedence level. Negation has
+/// higher precedence than any binary op.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
+/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
+/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
+/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
+/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
+AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
+ AffineLowPrecOp llhsOp) {
+ AffineExpr lhs;
+ if (!(lhs = parseAffineOperandExpr(llhs)))
+ return nullptr;
+
+ // Found an LHS. Deal with the ops.
+ if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
+ if (llhs) {
+ AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ return parseAffineLowPrecOpExpr(sum, lOp);
+ }
+ // No LLHS, get RHS and form the expression.
+ return parseAffineLowPrecOpExpr(lhs, lOp);
+ }
+ auto opLoc = parser.curToken.getLoc();
+ if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+ // We have a higher precedence op here. Get the rhs operand for the llhs
+ // through parseAffineHighPrecOpExpr.
+ AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+ if (!highRes)
+ return nullptr;
+
+ // If llhs is null, the product forms the first operand of the yet to be
+ // found expression. If non-null, the op to associate with llhs is llhsOp.
+ AffineExpr expr =
+ llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
+
+ // Recurse for subsequent low prec op's after the affine high prec op
+ // expression.
+ if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+ return parseAffineLowPrecOpExpr(expr, nextOp);
+ return expr;
+ }
+ // Last operand in the expression list.
+ if (llhs)
+ return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ // No llhs, 'lhs' itself is the expression.
+ return lhs;
+}
+
+/// Parse an affine expression.
+/// affine-expr ::= `(` affine-expr `)`
+/// | `-` affine-expr
+/// | affine-expr `+` affine-expr
+/// | affine-expr `-` affine-expr
+/// | affine-expr `*` affine-expr
+/// | affine-expr `floordiv` affine-expr
+/// | affine-expr `ceildiv` affine-expr
+/// | affine-expr `mod` affine-expr
+/// | bare-id
+/// | integer-literal
+///
+/// Additional conditions are checked depending on the production. For eg.,
+/// one of the operands for `*` has to be either constant/symbolic; the second
+/// operand for floordiv, ceildiv, and mod has to be a positive integer.
+AffineExpr AffineParser::parseAffineExpr() {
+ return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
+}
+
+SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
+ Token::Kind rDelim) {
+ parser.parseToken(lDelim, "expected lDelim at start of affine expr list");
+
+ SmallVector<AffineExpr, 4> exprs;
+ auto parseElt = [&]() -> LogicalResult {
+ auto elt = parseAffineExpr();
+ exprs.push_back(elt);
+ return elt ? success() : failure();
+ };
+
+ if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
+ /*allowEmptyList=*/true)))
+ llvm_unreachable("Failed AffineExpr parsing");
+
+ return exprs;
+}
+
+//===----------------------------------------------------------------------===//
+// TC parsing.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Base class for expressions involved in TC parsing.
+struct Expression {
+ enum class Kind {
+ Uninitialized = 0,
+ TensorExpr = 1,
+ TensorUse = 2,
+ };
+
+ explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
+ virtual ~Expression() = 0;
+
+ bool operator==(const Expression &e) const;
+ operator bool() const { return kind != Kind::Uninitialized; }
+
+ Kind kind;
+};
+
+/// Encodes a tensor use of the form:
+///
+/// affine-expr-list ::= affine-expr (`,` affine-expr)*
+/// tensor-use ::= bare-id `(` `)`
+/// | bare-id `(` affine-expr-list `)`
+///
+/// The affine-expr-list is stored as an AffineMap.
+struct TensorUse : public Expression {
+ TensorUse() : TensorUse("", AffineMap()) {}
+ TensorUse(StringRef name, AffineMap map)
+ : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
+ TensorUse(const TensorUse &use) = default;
+
+ static bool classof(const Expression *e) {
+ return e->kind == Kind::TensorUse;
+ }
+
+ bool operator==(const TensorUse &other) const {
+ return tensorId == other.tensorId && indexingMap == other.indexingMap;
+ }
+
+ /// Visitation function. Performs preorder or postorder traversal depending on
+ /// `PreOrder` and applies `callback` on each node.
+ template <typename Lambda, bool PreOrder>
+ void visit(Lambda callback) const;
+
+ StringRef tensorId;
+ AffineMap indexingMap;
+};
+
+/// Encodes a tensor expression of the form:
+///
+/// op-spec ::= bare-id `<` reduction-dims-list `>`
+/// | bare-id
+/// op-arg ::= tensor-expr
+/// | tensor-use
+/// op-arg-list ::= op-arg (`,` op-arg)*
+/// tensor-expr ::= op-spec `(` op-arg-list `)`
+///
+/// Underlying op-arg are stored by unique_ptr to base class.
+struct TensorExpr : public Expression {
+ TensorExpr(StringRef name,
+ SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
+ ArrayRef<unsigned> reductionDims)
+ : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)),
+ reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
+
+ static bool classof(const Expression *e) {
+ return e->kind == Kind::TensorExpr;
+ }
+
+ bool operator==(const TensorExpr &other) const {
+ if (opId != other.opId)
+ return false;
+ if (expressions.size() != other.expressions.size())
+ return false;
+ for (unsigned i = 0, e = expressions.size(); i < e; ++i)
+ if (*expressions[i] != *other.expressions[i])
+ return false;
+ for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
+ if (reductionDimensions[i] != other.reductionDimensions[i])
+ return false;
+ return true;
+ }
+
+ /// Visitation function. Performs preorder or postorder traversal depending on
+ /// `PreOrder` and applies `callback` on each node.
+ template <typename Lambda, bool PreOrder>
+ void visit(Lambda callback) const;
+
+ StringRef opId;
+ SmallVector<std::unique_ptr<Expression>, 4> expressions;
+ SetVector<unsigned> reductionDimensions;
+};
+
+Expression::~Expression() {}
+
+bool Expression::operator==(const Expression &e) const {
+ if (this->kind != e.kind)
+ return false;
+ if (e.kind == Expression::Kind::TensorUse)
+ return static_cast<const TensorUse &>(*this) ==
+ static_cast<const TensorUse &>(e);
+ if (e.kind == Expression::Kind::TensorExpr)
+ return static_cast<const TensorExpr &>(*this) ==
+ static_cast<const TensorExpr &>(e);
+ llvm_unreachable("Unexpected case");
+}
+
+/// This is a specialized parser for a TCDef.
+/// This maintains the dims it finds in an eager fashion.
+class TCParser {
+ enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
+
+public:
+ explicit TCParser(Parser &p);
+
+ /// Uses the AffineParser to parse the affine exprs used in a tensor
+ /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
+ /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
+ /// emitted on new identifiers.
+ SmallVector<AffineExpr, 4>
+ parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
+ Token::Kind lDelim = Token::Kind::l_paren,
+ Token::Kind rDelim = Token::Kind::r_paren);
+
+ /// Parse the information for a tensor def.
+ /// All the affine-expr must be dimensionless (i.e. contain only expressions
+ /// involving symbols and constants), but can otherwise contain arbitrary
+ /// affine expressions.
+ LogicalResult parseTensorDef(bool isOutput);
+
+ /// Parses a tensor use.
+ struct ComprehensionParsingState {
+ AffineDimList dims;
+ SmallVector<std::unique_ptr<Expression>, 4> expressions;
+ llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
+ };
+ LogicalResult parseTensorUse(TensorUse &result,
+ ComprehensionParsingState &state);
+
+ /// Parses a tensor expression.
+ LogicalResult parseExpression(TensorUse currentDefinition,
+ std::unique_ptr<Expression> &result,
+ ComprehensionParsingState &state);
+
+ /// Parse a single comprehension.
+ LogicalResult parseOneComprehension(StringRef cppOpName,
+ StringRef linalgOpName,
+ ComprehensionParsingState &state);
+
+ /// Parse and print the information for a TC def.
+ /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
+ /// When `gen-impl` is used, this prints the C++ implementation for the extra
+ /// methods defined in ODS (referenceIterators, referenceIndexingMaps and
+ /// regionBuilder).
+ LogicalResult parseAndEmitTCDef(llvm::raw_ostream &os);
+
+ /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
+ void printODS(llvm::raw_ostream &os, StringRef cppOpName,
+ StringRef linalgOpName);
+
+ /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
+ void printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state);
+
+ /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
+ void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state);
+
+ /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
+ void printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state);
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Internal bookkeeping of tensors.
+ //===--------------------------------------------------------------------===//
+ struct RegisteredTensor {
+ StringRef type;
+ AffineMap shape;
+ bool isOutput;
+ AffineMap indexingMap;
+ unsigned index;
+ };
+
+ //===--------------------------------------------------------------------===//
+ // Per-TC def state.
+ //===--------------------------------------------------------------------===//
+ /// Symbols are per TC def.
+ AffineSymbolList symbols;
+ /// Tensors are per TC def.
+ llvm::StringMap<RegisteredTensor> registeredTensors;
+ unsigned nextRegisteredTensorIndex;
+
+ Parser &parser;
+};
+} // namespace
+
+namespace llvm {
+
+template <>
+struct DenseMapInfo<TensorUse> {
+ static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
+ static TensorUse getTombstoneKey() {
+ return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
+ DenseMapInfo<AffineMap>::getTombstoneKey());
+ }
+ static unsigned getHashValue(const TensorUse &val) {
+ return ::llvm::hash_value(val.tensorId); // don't care about collisions.
+ }
+ static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
+ return LHS == RHS;
+ }
+};
+
+} // namespace llvm
+
+//===----------------------------------------------------------------------===//
+// Visitation functions.
+//===----------------------------------------------------------------------===//
+
+template <typename Lambda, bool PreOrder>
+void visit(const Expression &expr, Lambda callback) {
+ switch (expr.kind) {
+ default:
+ llvm_unreachable("Unexpected kind");
+ case Expression::Kind::TensorExpr:
+ static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
+ break;
+ case Expression::Kind::TensorUse:
+ static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
+ break;
+ }
+}
+
+template <typename Lambda>
+void visitPreorder(const Expression &expr, Lambda callback) {
+ visit<Lambda, false>(expr, callback);
+}
+
+template <typename Lambda>
+void visitPostorder(Expression &expr, Lambda callback) {
+ visit<Lambda, true>(expr, callback);
+}
+
+template <typename Lambda, bool PreOrder>
+void TensorExpr::visit(Lambda callback) const {
+ if (!PreOrder)
+ callback(*this);
+ for (auto &e : expressions)
+ ::visit<Lambda, PreOrder>(*e, callback);
+ if (PreOrder)
+ callback(*this);
+}
+
+template <typename Lambda, bool PreOrder>
+void TensorUse::visit(Lambda callback) const {
+ callback(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// TC parsing functions.
+//===----------------------------------------------------------------------===//
+TCParser::TCParser(Parser &p)
+ : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
+
+/// Uses the AffineParser to parse the affine exprs used in a tensor
+/// definition. All identifiers are interpreted as symbols, new symbols are
+/// added eagerly.
+SmallVector<AffineExpr, 4>
+TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
+ AffineDimList &dims, Token::Kind lDelim,
+ Token::Kind rDelim) {
+ AffineParser affineParser(
+ parser,
+ [&](StringRef sRef) {
+ AffineExpr expr;
+ if (discoveryMode == EagerDiscoveryMode::Symbols) {
+ expr = getAffineSymbolExpr(symbols.size(), parser.context);
+ symbols.emplace_back(sRef, expr);
+ } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
+ expr = getAffineDimExpr(dims.size(), parser.context);
+ dims.emplace_back(sRef, expr);
+ }
+ return expr;
+ },
+ dims, symbols);
+ return affineParser.parseAffineExprs(lDelim, rDelim);
+}
+
+/// Parse the information for a tensor def of the form:
+///
+/// affine-expr-list ::= affine-expr (`,` affine-expr )*
+/// tensor-typedef ::= type `(` `)`
+/// | type `(` affine-expr-list `)`
+/// tensor-def ::= bare-id `:` tensor-typedef
+LogicalResult TCParser::parseTensorDef(bool isOutput) {
+ StringRef tensorId = parser.curToken.getSpelling();
+ if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
+ failed(parser.parseToken(Token::Kind::colon, "expected colon")))
+ return failure();
+
+ StringRef tensorType = parser.curToken.getSpelling();
+ if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
+ return failure();
+
+ AffineDimList emptyDims;
+ auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
+ assert(emptyDims.empty() && "Unexpected dimension in tensor def");
+ AffineMap map =
+ AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
+
+ auto iterBoolPair = registeredTensors.try_emplace(
+ tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
+ nextRegisteredTensorIndex++});
+ assert(iterBoolPair.second && "Could not emplace tensor registration");
+ LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
+ << "with typeString: " << tensorType << " "
+ << "and shape: " << map << "\n");
+
+ return success();
+}
+
+/// Parses a tensor use of the form:
+///
+/// affine-expr-list ::= affine-expr (`,` affine-expr)*
+/// tensor-use ::= bare-id `(` `)`
+/// | bare-id `(` affine-expr-list `)`
+LogicalResult TCParser::parseTensorUse(TensorUse &result,
+ ComprehensionParsingState &state) {
+ StringRef tensorId = parser.curToken.getSpelling();
+ if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
+ return failure();
+
+ auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
+ AffineMap map =
+ AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
+ LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
+ << "\n");
+
+ result = TensorUse(tensorId, map);
+ return success();
+}
+
+/// Parses a tensor expression of the form:
+///
+/// op-spec ::= bare-id `<` reduction-dims-list `>`
+/// | bare-id
+/// op-arg ::= tensor-expr
+/// | tensor-use
+/// op-arg-list ::= op-arg (`,` op-arg)*
+/// tensor-expr ::= op-spec `(` op-arg-list `)`
+LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
+ std::unique_ptr<Expression> &result,
+ ComprehensionParsingState &state) {
+ StringRef opOrTensor = parser.curToken.getSpelling();
+ if (registeredTensors.count(opOrTensor) > 0) {
+ TensorUse use;
+ auto res = parseTensorUse(use, state);
+ if (failed(res))
+ return res;
+ result = std::make_unique<TensorUse>(use);
+ return success();
+ }
+
+ if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
+ return failure();
+
+ // This is an op.
+ SmallVector<unsigned, 4> reductionDims;
+ SmallVector<std::unique_ptr<Expression>, 4> expressions;
+
+ // Check if it has a reduction set, discover dimensions eagerly.
+ if (parser.curToken.is(Token::Kind::lt)) {
+ auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
+ Token::Kind::lt, Token::Kind::gt);
+ for (auto iter : iters)
+ reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
+ }
+
+ // If this op is a reduction, it's first argument is the `currentDefinition`
+ // tensor use.
+ if (!reductionDims.empty())
+ expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
+ LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
+
+ auto parseExpr = [&]() -> LogicalResult {
+ std::unique_ptr<Expression> e;
+ if (failed(parseExpression(currentDefinition, e, state)))
+ return failure();
+ expressions.push_back(std::move(e));
+ return success();
+ };
+ if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
+ failed(parser.parseCommaSeparatedListUntil(
+ Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
+ return failure();
+
+ result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
+ reductionDims);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Parse and Emit functions.
+//===----------------------------------------------------------------------===//
+
+/// Parse the information for a single comprehension.
+///
+/// tensor-def-list ::= tensor-def (`,` tensor-def)*
+/// tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
+/// comprehension ::= tensor-def-list `=` tensor-expr-list `;`
+LogicalResult
+TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
+ ComprehensionParsingState &state) {
+ // 1. Parse LHS of `=`, these become the definitions that appear as the output
+ // tensors or read/write buffers.
+ SmallVector<TensorUse, 4> definitions;
+ auto parseUse = [&]() -> LogicalResult {
+ TensorUse use;
+ if (failed(parseTensorUse(use, state)))
+ return failure();
+ definitions.push_back(use);
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
+ /*allowEmptyList=*/true)))
+ return failure();
+
+ // 2. Parse RHS of `=`, this becomes the expressions from which we emit
+ // computations.
+ unsigned idx = 0;
+ auto parseExpr = [&]() -> LogicalResult {
+ std::unique_ptr<Expression> expr;
+ if (idx >= definitions.size()) {
+ parser.emitError("Fewer LHS definitions than RHS expressions");
+ return failure();
+ }
+ if (failed(parseExpression(definitions[idx++], expr, state)))
+ return failure();
+ state.expressions.push_back(std::move(expr));
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedListUntil(
+ Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
+ return failure();
+ if (idx != definitions.size()) {
+ parser.emitError("Fewer RHS expressions than LHS definitions");
+ return failure();
+ }
+
+ // 3. Postprocess.
+ // 3.a. Normalize all maps to the proper state.dims and symbols counts.
+ SmallVector<TensorUse, 4> allUses;
+ allUses.reserve(registeredTensors.size());
+ for (auto &def : definitions)
+ allUses.push_back(def);
+ for (auto &pExpr : state.expressions)
+ visitPostorder(*pExpr, [&](const Expression &e) {
+ if (auto *use = dyn_cast<TensorUse>(&e))
+ allUses.push_back(*use);
+ });
+ for (auto &use : allUses)
+ use.indexingMap =
+ AffineMap::get(state.dims.size(), symbols.size(),
+ use.indexingMap.getResults(), parser.context);
+
+ // 3.b. Traverse definitions
+ llvm::DenseSet<StringRef> seenDefs;
+ for (auto &def : definitions) {
+ if (seenDefs.count(def.tensorId) > 0) {
+ parser.emitError("Unexpected multi-write to a single tensor");
+ return failure();
+ }
+ seenDefs.insert(def.tensorId);
+ auto tensorIter = registeredTensors.find(def.tensorId);
+ assert(tensorIter != registeredTensors.end() && "unregistered tensor");
+ auto &tensor = tensorIter->getValue();
+ tensor.indexingMap = def.indexingMap;
+ state.orderedTensorArgs[def] = tensor.index;
+ }
+
+ bool failed = false;
+ for (auto &pExpr : state.expressions)
+ visitPostorder(*pExpr, [&](const Expression &e) {
+ auto *pUse = dyn_cast<TensorUse>(&e);
+ if (failed || !pUse)
+ return;
+ auto &use = *pUse;
+ LLVM_DEBUG(llvm::dbgs()
+ << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
+ auto tensorIter = registeredTensors.find(use.tensorId);
+ assert(tensorIter != registeredTensors.end() && "unregistered tensor");
+ auto &tensor = tensorIter->getValue();
+ if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
+ parser.emitError(
+ "Unexpected multi-read of a tensor with different accesses");
+ failed = true;
+ return;
+ }
+ seenDefs.insert(use.tensorId);
+ tensor.indexingMap = use.indexingMap;
+ state.orderedTensorArgs[use] = tensor.index;
+ });
+ if (failed)
+ return failure();
+
+ return success();
+}
+
+/// Parse and print the information for a TC def.
+///
+/// tensor-def-list ::= tensor-def (`,` tensor-def )*
+///
+/// comprehension-list ::= comprehension comprehension*
+///
+/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
+/// `{` comprehension-list `}`
+///
+/// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
+/// contain only expressions involving symbols and constants), but can
+/// otherwise contain arbitrary affine expressions.
+LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) {
+ if (failed(parser.parseToken(Token::Kind::kw_def,
+ "expected 'def' to define a TC")))
+ return failure();
+
+ StringRef tcName = parser.curToken.getSpelling();
+ LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n");
+ if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
+ failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
+ return failure();
+
+ auto parseInputDef = [&]() -> LogicalResult {
+ return parseTensorDef(/*isOutput=*/false);
+ };
+ if (failed(parser.parseCommaSeparatedListUntil(
+ Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
+ return failure();
+
+ if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
+ failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
+ failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
+ return failure();
+ auto parseOutputDef = [&]() -> LogicalResult {
+ return parseTensorDef(/*isOutput=*/true);
+ };
+ if (failed(parser.parseCommaSeparatedListUntil(
+ Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
+ return failure();
+
+ // Since we don't declare symbols separately, we discover them eagerly: each
+ // newly encountered id in a tensor shape expression is treated as a new
+ // symbolic. At this point, all tensors have been parsed and all the symbols
+ // that could be discovered eagerly are now known. Resize all AffineMaps to
+ // normalize the number of eagerly discovered symbols.
+ for (auto &tensor : registeredTensors) {
+ auto &map = tensor.getValue().shape;
+ map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
+ parser.context);
+ }
+
+ if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
+ return failure();
+
+ SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
+ while (parser.curToken.isNot(Token::Kind::r_brace)) {
+ perComprehensionStates.push_back(ComprehensionParsingState());
+ if (failed(parseOneComprehension(tcName, tcName,
+ perComprehensionStates.back())))
+ return failure();
+ };
+ parser.parseToken(Token::Kind::r_brace, "expected '}'");
+
+ // Print.
+ auto nComprehensions = perComprehensionStates.size();
+ if (nComprehensions != 1) {
+ parser.emitError("only 1 comprehension supported for now, got: " +
+ llvm::Twine(nComprehensions));
+ return failure();
+ }
+ if (genODSDecl) {
+ printODS(os, tcName, tcName);
+ os << "\n";
+ }
+ if (genODSImpl) {
+ auto &state = perComprehensionStates.back();
+ std::string extraMethods;
+ llvm::raw_string_ostream ss(extraMethods);
+ printReferenceIterators(ss, tcName, state);
+ printReferenceIndexingMaps(ss, tcName, state);
+ printRegionBuilder(ss, tcName, state);
+ ss.flush();
+ os << extraMethods << "\n";
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing functions
+//===----------------------------------------------------------------------===//
+
+/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
+void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
+ StringRef linalgOpName) {
+ const char *header = R"FMT( def {0}Op : LinalgNamedStructured_Op<"{1}", [
+ NInputs<{2}>,
+ NOutputs<{3}>,
+ NamedStructuredOpTraits]> {
+ let arguments = (ins Variadic<LinalgOperand>:$views);
+ let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+ let extraClassDeclaration = [{{
+ llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
+ llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
+ void regionBuilder(ArrayRef<BlockArgument> args);
+ }];
+ let hasFolder = 1;
+ })FMT";
+
+ unsigned nInputs = 0, nOutputs = 0;
+ for (auto &t : registeredTensors) {
+ if (t.getValue().isOutput)
+ nOutputs++;
+ else
+ nInputs++;
+ }
+
+ os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
+}
+
+/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
+void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state) {
+ const char *referenceReferenceIteratorsFmt =
+ R"FMT(
+ llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {
+ return SmallVector<StringRef, 8>{{ {1} };
+ })FMT";
+
+ std::string iteratorsStr;
+ llvm::raw_string_ostream ss(iteratorsStr);
+ unsigned pos = 0;
+ interleaveComma(state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
+ bool reduction = false;
+ for (auto &expr : state.expressions) {
+ visitPostorder(*expr, [&](const Expression &e) {
+ if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
+ if (pTensorExpr->reductionDimensions.count(pos) > 0)
+ reduction = true;
+ }
+ });
+ if (reduction)
+ break;
+ }
+ ss << (reduction ? "getReductionIteratorTypeName()"
+ : "getParallelIteratorTypeName()");
+ pos++;
+ });
+ ss.flush();
+
+ os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr);
+}
+
+/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
+void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state) {
+ const char *referenceIndexingMapsFmt =
+ R"FMT(
+ llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {
+ MLIRContext *context = getContext();
+ AffineExpr {1};
+ bindDims(context, {1});
+ return SmallVector<AffineMap, 8>{{ {2} };
+ })FMT";
+
+ std::string dimsStr;
+ llvm::raw_string_ostream ss(dimsStr);
+ interleaveComma(state.dims, ss,
+ [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
+ ss.flush();
+
+ std::string mapsStr;
+ llvm::raw_string_ostream mapsStringStream(mapsStr);
+ SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
+ for (auto it : state.orderedTensorArgs)
+ orderedUses[it.second] = it.first;
+ interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
+ assert(u.indexingMap);
+ const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})";
+ if (u.indexingMap.isEmpty()) {
+ mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
+ return;
+ }
+
+ std::string exprsStr;
+ llvm::raw_string_ostream exprsStringStream(exprsStr);
+ exprsStringStream << "{";
+ interleaveComma(u.indexingMap.getResults(), exprsStringStream);
+ exprsStringStream << "}";
+ exprsStringStream.flush();
+
+ mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
+ });
+ mapsStringStream.flush();
+
+ os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr);
+}
+
+/// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
+void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+ ComprehensionParsingState &state) {
+ unsigned count = state.orderedTensorArgs.size();
+ llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
+ std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
+ printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
+ if (auto *pUse = dyn_cast<TensorUse>(&e)) {
+ os << "_" << state.orderedTensorArgs.find(*pUse)->second;
+ return;
+ }
+ auto *pTensorExpr = cast<TensorExpr>(&e);
+ if (subExprsMap.count(pTensorExpr) > 0) {
+ os << "_" << subExprsMap[pTensorExpr];
+ } else {
+ std::string subExprs;
+ llvm::raw_string_ostream subExprsStringStream(subExprs);
+ interleaveComma(pTensorExpr->expressions, subExprsStringStream,
+ [&](const std::unique_ptr<Expression> &e) {
+ printExpr(subExprsStringStream, *e);
+ });
+ subExprsStringStream.flush();
+ const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});";
+ os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs);
+ subExprsMap[pTensorExpr] = count;
+ }
+ };
+
+ const char *regionBuilderFmt = R"FMT(
+ void {0}::regionBuilder(ArrayRef<BlockArgument> args) {
+ using namespace edsc;
+ using namespace intrinsics;
+ ValueHandle {1};
+ {2}
+ (linalg_yield(ValueRange{ {3} }));
+ })FMT";
+
+ unsigned idx = 0;
+ std::string valueHandleStr;
+ llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
+ interleaveComma(state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
+ valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
+ idx++;
+ });
+
+ std::string expressionsStr;
+ llvm::raw_string_ostream expressionStringStream(expressionsStr);
+ for (auto &expr : state.expressions)
+ visitPostorder(*expr, [&](const Expression &e) {
+ if (e.kind == Expression::Kind::TensorExpr)
+ printExpr(expressionStringStream, e);
+ });
+
+ std::string yieldStr;
+ llvm::raw_string_ostream yieldStringStream(yieldStr);
+ interleaveComma(state.expressions, yieldStringStream,
+ [&](const std::unique_ptr<Expression> &e) {
+ printExpr(yieldStringStream, *e);
+ });
+
+ valueHandleStringStream.flush();
+ expressionStringStream.flush();
+ yieldStringStream.flush();
+
+ os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr,
+ yieldStr);
+}
+
+/// Iterate over each Tensor Comprehension def.
+LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
+ Parser &parser) {
+ while (parser.curToken.getKind() != Token::Kind::eof) {
+ TCParser tcParser(parser);
+ if (failed(tcParser.parseAndEmitTCDef(os)))
+ return failure();
+ }
+ return success();
+}
+
+int main(int argc, char **argv) {
+ llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
+
+ // Set up the input file.
+ std::string errorMessage;
+ std::unique_ptr<llvm::MemoryBuffer> file =
+ mlir::openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ std::unique_ptr<llvm::ToolOutputFile> output =
+ openOutputFile(outputFilename, &errorMessage);
+ if (!output) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ // Include the proper Linalg header for end-to-end tblgen testing without
+ // resorting to non-portable shell manipulations.
+ if (testEmitIncludeTdHeader)
+ output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
+
+ MLIRContext context;
+ llvm::SourceMgr mgr;
+ mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+ Parser parser(mgr, &context);
+ parseAndEmitAllTensorComprehensions(output->os(), parser);
+ output->keep();
+
+ return 0;
+}