[Mlir-commits] [mlir] e6f2f17 - [mlir][Linalg] Refactor StructuredOpInterface - NFC
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 11 04:53:50 PDT 2020
Author: Nicolas Vasilache
Date: 2020-09-11T07:53:12-04:00
New Revision: e6f2f17f05a1248b069ba830c4afffd61ee2f297
URL: https://github.com/llvm/llvm-project/commit/e6f2f17f05a1248b069ba830c4afffd61ee2f297
DIFF: https://github.com/llvm/llvm-project/commit/e6f2f17f05a1248b069ba830c4afffd61ee2f297.diff
LOG: [mlir][Linalg] Refactor StructuredOpInterface - NFC
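This revision refactors and cleans up StructuredOpInterface to simplify it
before work can proceed on Linalg on tensors:
- move the methods of the StructuredOpTraits trait that are part of the
  StructuredOpInterface into the interface itself, as default method
  implementations (the trait now only provides verification),
- drop referenceIterators and referenceIndexingMaps, which ended up being more
  confusing than useful; ops such as copy, fill, conv and the pooling ops now
  return their iterator_types and indexing_maps attributes directly,
- drop NamedStructuredOpTraits.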
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e003fd15d0b1..ac6e9317fa32 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -130,21 +130,22 @@ def CopyOp : LinalgStructured_Op<"copy", [
let extraClassDeclaration = libraryCallName # [{
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
- unsigned nPar = input().getType().cast<ShapedType>().getRank();
- return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
+ ArrayAttr iterator_types() {
+ unsigned nPar = getInputShapedType(0).getRank();
+ return Builder(getContext()).getStrArrayAttr(
+ SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
}
// I(input_perm(ivs)) -> O(output_perm(ivs))
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+ ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto maybeInputMap = inputPermutation();
auto maybeOutputMap = outputPermutation();
unsigned inputRank = getInputShapedType(0).getRank();
unsigned outputRank = getOutputShapedType(0).getRank();
- return SmallVector<AffineMap, 8>{
+ return Builder(getContext()).getAffineMapArrayAttr({
extractOrIdentityMap(maybeInputMap, inputRank, context),
- extractOrIdentityMap(maybeOutputMap, outputRank, context)};
+ extractOrIdentityMap(maybeOutputMap, outputRank, context)});
}
Value getSource() { return input();}
@@ -163,16 +164,17 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
let extraClassDeclaration = libraryCallName # [{
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
- unsigned nPar = output().getType().cast<ShapedType>().getRank();
- return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
+ ArrayAttr iterator_types() {
+ unsigned nPar = getOutputShapedType(0).getRank();
+ return Builder(getContext()).getStrArrayAttr(
+ SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
}
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+ ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
// filling_value -> O(ivs)
- return SmallVector<AffineMap, 8>{
- extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)};
+ return Builder(getContext()).getAffineMapArrayAttr({
+ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
}];
@@ -295,7 +297,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
getNumOutputFeatureDimensions();
}
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+ ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions; i.e.
// [b, xs, q] in the TF notation above.
unsigned nPar = getOutputShapedType(0).getRank();
@@ -310,7 +312,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
iters.reserve(nPar + nRed + nWin);
iters.append(nRed, getReductionIteratorTypeName());
iters.append(nWin, getWindowIteratorTypeName());
- return iters;
+ return Builder(getContext()).getStrArrayAttr(iters);
}
// F(z0, ..., zN-1, q, k) *
@@ -318,7 +320,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// -> O(b, x0, ..., xN-1, k)
// for N equal to `nWindow`. If there is no padding attribute, it will be
// ignored.
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+ ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto nWin = getNumWindowLoops();
assert(nWin > 0 && "expected at least one window dimension");
@@ -343,7 +345,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
auto zs = makeAffineDimExprs(nWin, idx, context);
// Construct the weighedSum expression.
auto ws = weightedPoolingInputIndex(*this, xs, zs);
- return SmallVector<AffineMap, 8>{
+ return Builder(getContext()).getAffineMapArrayAttr({
// filter[z[0], ..., z[N-1], q, k]
AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context),
// input[b,
@@ -353,7 +355,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// q]
AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context),
// output[b, x[0], ..., x[N-1], k]
- AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)};
+ AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)});
}
}];
@@ -384,7 +386,7 @@ class SingleInputPoolingBase_Op<string mnemonic>
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = commonUtils# [{
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+ ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions.
unsigned nPar = getOutputShapedType(0).getRank();
// The window loops has the same number loops with output dimensions.
@@ -392,10 +394,10 @@ class SingleInputPoolingBase_Op<string mnemonic>
SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
iters.reserve(nPar + nWin);
iters.append(nWin, getWindowIteratorTypeName());
- return iters;
+ return Builder(getContext()).getStrArrayAttr(iters);
}
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+ ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto nPar = getNumParallelLoops();
auto nWin = getNumWindowLoops();
@@ -406,14 +408,13 @@ class SingleInputPoolingBase_Op<string mnemonic>
// Construct the weighedSum expression.
auto inputDims =
weightedPoolingInputIndex(*this, outputDims, windowDims);
- return SmallVector<AffineMap, 8>{
+ return Builder(getContext()).getAffineMapArrayAttr({
// input
AffineMap::get(idx, 0, inputDims, context),
// windowDims
AffineMap::get(idx, 0, windowDims, context),
// output
- AffineMap::get(idx, 0, outputDims, context)
- };
+ AffineMap::get(idx, 0, outputDims, context)});
}
}];
@@ -466,7 +467,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
OptionalAttr<StrAttr>:$library_call,
Confined<OptionalAttr<I64Attr>,
[IntMinValue<0>]>:$symbol_source);
- let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+ let results = (outs Variadic<AnyRankedTensor>:$output_lis);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
@@ -485,16 +486,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
return library_call().hasValue() ? library_call().getValue() : "";
}
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
- llvm_unreachable(
- "No such thing as reference iterator types for a generic op.");
- }
-
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
- llvm_unreachable(
- "No such thing as reference indexing maps for a generic op.");
- }
-
llvm::Optional<unsigned> getSymbolSource() {
auto ss = symbol_source();
return ss.hasValue() ?
@@ -807,8 +798,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
-def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
-
class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
: LinalgStructuredBase_Op<mnemonic, props> {
string spec = ?;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 82882b083b2d..f32b70efd87e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -23,168 +23,486 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// Loop types handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
- "Return the number of parallel loops within the current operation.",
- "unsigned", "getNumParallelLoops"
+ /*desc=*/[{
+ Return the number of parallel loops within the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumParallelLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumIterators(getParallelIteratorTypeName(),
+ $_op.iterator_types());
+ }]
>,
InterfaceMethod<
- "Return the number of reduction loops within the current operation.",
- "unsigned", "getNumReductionLoops"
+ /*desc=*/[{
+ Return the number of reduction loops within the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumReductionLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumIterators(getReductionIteratorTypeName(),
+ $_op.iterator_types());
+ }]
>,
InterfaceMethod<
- "Return the number of window loops within the current operation.",
- "unsigned", "getNumWindowLoops"
+ /*desc=*/[{
+ Return the number of window loops within the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumWindowLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumIterators(getWindowIteratorTypeName(),
+ $_op.iterator_types());
+ }]
>,
InterfaceMethod<
- "Return the number of loops within the current operation.",
- "unsigned", "getNumLoops">,
-
+ /*desc=*/[{
+ Return the total number of loops within the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumIterators($_op.iterator_types());
+ }]
+ >,
InterfaceMethod<
- [{Returns true if the current operation has only one loop and it's a
- reduction loop}],
- "bool", "hasSingleReductionLoop">,
-
+ /*desc=*/[{
+ Returns true if the current operation has only one loop and it's a
+ reduction loop.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasSingleReductionLoop",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto iters = $_op.iterator_types();
+ return iters.size() == 1 &&
+ getNumIterators(getReductionIteratorTypeName(), iters) == 1;
+ }]>,
//===------------------------------------------------------------------===//
- // Input arguments handling.
+ // Num input/output arguments handling.
//===------------------------------------------------------------------===//
+ // These special methods must be defined by each op that wants to implement
+ // the LinalgStructuredInterface. For now, this is either:
+ // - inherited statically by using the NInputs<unsigned> or
+ // NOutputs<unsigned> traits.
+ // - derived from args_in/args_out attributes (for linalg.generic and
+ // linalg.indexed_generic ops).
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of inputs from the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumInputs"
+ >,
InterfaceMethod<
- "Return the number of inputs from the current operation.",
- "unsigned", "getNumInputs"
+ /*desc=*/[{
+ Return the number of outputs from the current operation.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumOutputs"
>,
- InterfaceMethod<"Return the input view at the given index.",
- "Value", "getInput", (ins "unsigned":$i)
+ //===------------------------------------------------------------------===//
+ // Input arguments handling.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `i`-th input value.
+ The `i^th` input argument is always the `i^th` operand regardless of
+ whether we have tensors or buffers.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getInput",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i < $_op.getNumInputs());
+ return this->getOperation()->getOperand(i);
+ }]
>,
- InterfaceMethod<[{
+ InterfaceMethod<
+ /*desc=*/[{
Return the index of the given input value `v`, or `None` if the value is
not an input.
}],
- "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value":$v)
+ /*retTy=*/"llvm::Optional<unsigned>",
+ /*methodName=*/"getIndexOfInput",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto it = llvm::find(getInputs(), value);
+ if (it != getInputs().end())
+ return it - getInputs().begin();
+ return llvm::None;
+ }]
>,
InterfaceMethod<
- "Return the input operands from the current operation.",
- "Operation::operand_range", "getInputs"
- >,
- InterfaceMethod<[{
+ /*desc=*/[{
Return the `i`-th input shaped type, irrespective of buffer or tensor
type.
- }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
- InterfaceMethod<[{
+ }],
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getInputShapedType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getInput(i).getType().template cast<ShapedType>();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input operands from the current operation.
+ }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getInputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = this->getOperation()->getOperands();
+ return {range.begin(), range.begin() + $_op.getNumInputs()};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
Return the subset of input operands that are of ranked tensor type.
- }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+ }],
+ /*retTy=*/"SmallVector<RankedTensorType, 4>",
+ /*methodName=*/"getInputTensorTypes" ,
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<RankedTensorType, 4> res;
+ for (Type type : getInputs().getTypes())
+ if (auto t = type.template dyn_cast<RankedTensorType>())
+ res.push_back(t);
+ return res;
+ }]
+ >,
//===------------------------------------------------------------------===//
// Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
- "Return the number of outputs from the current operation.",
- "unsigned", "getNumOutputs"
- >,
- InterfaceMethod<"Return the output buffer at the given index.",
- "Value", "getOutputBuffer", (ins "unsigned":$i)
+ /*desc=*/[{
+ Return the output buffer at the given index, asserts that this is a
+ buffer operand and not a tensor result.
+ The `i^th` output argument is an operand (resp. a return value) iff it
+ is a value of buffer type (resp. a return value of tensor type).
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getOutputBuffer",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Output buffers are passed as output buffer operands (side-effecting).
+ // Output tensors are results.
+ // The union of the 2 are all the outputs and we want to ensure i does
+ // not overflow the buffer operands.
+ assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
+ && "overflowing output buffer index");
+ return this->getOperation()->getOperand($_op.getNumInputs() + i);
+ }]
>,
- InterfaceMethod<[{
+ InterfaceMethod<
+ /*desc=*/[{
Return the index of the given buffer value, or `None` if the value is
not part of the output buffers.
}],
- "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value":$view)
+ /*retTy=*/"llvm::Optional<unsigned>",
+ /*methodName=*/"getIndexOfOutputBuffer",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto it = llvm::find(getOutputBuffers(), value);
+ if (it != getOutputBuffers().end())
+ return it - getOutputBuffers().begin();
+ return llvm::None;
+ }]
>,
- InterfaceMethod<[{
+ InterfaceMethod<
+ /*desc=*/[{
Return the type of the output buffer at the given index.
- }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
- InterfaceMethod<[{
+ }],
+ /*retTy=*/"MemRefType",
+ /*methodName=*/"getOutputBufferType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getOutputBuffer(i).getType().template cast<MemRefType>();
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return the `i`-th output shaped type, irrespective of buffer or tensor
type.
- }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
- InterfaceMethod<[{
+ }],
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getOutputShapedType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getShapedType(i + $_op.getNumInputs());
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return the results that are of ranked tensor type.
- }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
+ }],
+ /*retTy=*/"SmallVector<RankedTensorType, 4>",
+ /*methodName=*/"getOutputTensorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<RankedTensorType, 4> res;
+ for (Type type : this->getOperation()->getResults().getTypes())
+ res.push_back(type.template cast<RankedTensorType>());
+ return res;
+ }]>,
InterfaceMethod<
- "Return the output buffers (operands) from the current operation.",
- "Operation::operand_range", "getOutputBuffers"
+ /*desc=*/[{
+ Return the output buffers (operands) from the current operation.
+ }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getOutputBuffers",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = this->getOperation()->getOperands();
+ return {range.begin() + $_op.getNumInputs(),
+ range.begin() + getNumInputsAndOutputBuffers()};
+ }]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
- "Return one single buffer at position `$i`.",
- "Value", "getBuffer", (ins "unsigned":$i)
+ /*desc=*/[{
+ Return one single buffer at position `$i`.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getBuffer",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
+ return this->getOperation()->getOperand(i);
+ }]
>,
InterfaceMethod<
- "Return the number of inputs and outputs, irrespective of their buffer "
- "or tensor type.",
- "unsigned", "getNumInputsAndOutputs"
+ /*desc=*/[{
+ Return the number of inputs and outputs, irrespective of their buffer or
+ tensor type.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumInputsAndOutputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getNumInputs() + $_op.getNumOutputs();
+ }]
>,
InterfaceMethod<
- "Return the number of inputs, irrespective of their buffer or tensor "
- "type, and output buffers",
- "unsigned", "getNumInputsAndOutputBuffers"
+ /*desc=*/[{
+ Return the number of inputs, irrespective of their buffer or tensor type
+ and output buffers
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumInputsAndOutputBuffers",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getNumInputs() + $_op.getNumOutputs() -
+ this->getOperation()->getNumResults();
+ }]
>,
InterfaceMethod<
- "Return the range over inputs (irrespective of type) and output buffers.",
- "Operation::operand_range", "getInputsAndOutputBuffers"
+ /*desc=*/[{
+ Return the range over inputs (irrespective of type) and output buffers.
+ }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getInputsAndOutputBuffers",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = this->getOperation()->getOperands();
+ return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
+ }]
>,
InterfaceMethod<
- "Return the shaped types for all the inputs and outputs",
- "SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
+ /*desc=*/[{
+ Return the `i`-th shaped type, there are 3 cases:
+ 1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
+ otherwise
+ 2. if `i < getNumInputsAndOutputBuffers()` then return the
+ `getOutputBufferType(i - $_op.getNumInputs())`; otherwise
+ 3. return the `i - getNumInputsAndOutputBuffers()` result type.
+ }],
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getShapedType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (i < $_op.getNumInputs())
+ return getInputShapedType(i);
+ if (i < getNumInputsAndOutputBuffers())
+ return getOutputBufferType(i - $_op.getNumInputs());
+ return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the shaped types for all the inputs and outputs
+ }],
+ /*retTy=*/"SmallVector<ShapedType, 4>",
+ /*methodName=*/"getInputOutputShapedTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<Type, 4> inputOutputTypes(
+ this->getOperation()->operand_type_begin(),
+ this->getOperation()->operand_type_end());
+ inputOutputTypes.append(this->getOperation()->result_type_begin(),
+ this->getOperation()->result_type_end());
+ return llvm::to_vector<4>(
+ llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
+ return type.cast<ShapedType>();
+ }));
+ }]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
- "Return the reference iterators for this named op (if any are "
- "specified). These reference iterators are used to specify the default "
- "behavior of the op. Typically this would be a static method but in "
- "order to allow rank-polymorphic ops, this needs to be per object "
- "instance. Named ops must define referenceIterators, even if empty for "
- "the 0-D case. Generic ops on the other hand have a None "
- "`referenceIterators`",
- "llvm::Optional<SmallVector<StringRef, 8>>", "referenceIterators"
+ /*desc=*/[{
+ Return the iterator types attribute within the current operation.
+ }],
+ /*retTy=*/"ArrayAttr",
+ /*methodName=*/"iterator_types",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.iterator_types();
+ }]
>,
InterfaceMethod<
- "Return the reference indexing maps for this named op (if any are "
- "specified). Typically this would be a static method but in order to "
- "allow rank-polymorphic ops, this needs to be per object instance. Named "
- "ops must define referenceIterators, even if empty for the 0-D case. "
- "Generic ops on the other hand have a None `referenceIndexingMaps`",
- "llvm::Optional<SmallVector<AffineMap, 8>>", "referenceIndexingMaps"
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"ArrayAttr",
+ /*methodName=*/"indexing_maps"
>,
InterfaceMethod<
- "Return the iterator types attribute within the current operation.",
- "ArrayAttr", "iterator_types"
+ /*desc=*/[{
+ Return the indexing maps within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap, 4>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return llvm::to_vector<4>(
+ llvm::map_range($_op.indexing_maps(),
+ [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+ }]
>,
InterfaceMethod<
- "Return the indexing maps attribute within the current operation.",
- "ArrayAttr", "indexing_maps"
+ /*desc=*/[{
+ Return the input or output indexing map at index `i`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getIndexingMap",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i < getNumInputsAndOutputs());
+ return $_op.indexing_maps()
+ .getValue()[i]
+ .template cast<AffineMapAttr>()
+ .getValue();
+ }]
>,
InterfaceMethod<
- "Return the indexing maps within the current operation.",
- "SmallVector<AffineMap, 4>", "getIndexingMaps"
- >,
- InterfaceMethod<"Return the input or output indexing map at index `i`.",
- "AffineMap", "getIndexingMap", (ins "unsigned":$i)
- >,
- InterfaceMethod<"Return the input indexing map at index `i`.",
- "AffineMap", "getInputIndexingMap", (ins "unsigned":$i)
+ /*desc=*/[{
+ Return the input indexing map at index `i`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getInputIndexingMap",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i < $_op.getNumInputs());
+ return $_op.indexing_maps()
+ .getValue()[i]
+ .template cast<AffineMapAttr>()
+ .getValue();
+ }]
>,
- InterfaceMethod<"Return the output indexing map at index `i`.",
- "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i)
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the output indexing map at index `i`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getOutputIndexingMap",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i < $_op.getNumOutputs());
+ return $_op.indexing_maps()
+ .getValue()[i + $_op.getNumInputs()]
+ .template cast<AffineMapAttr>()
+ .getValue();
+ }]
>,
- InterfaceMethod<[{
+ InterfaceMethod<
+ /*desc=*/[{
Return whether the op has only MemRef input and outputs.
- }], "bool", "hasBufferSemantics">,
- InterfaceMethod<[{
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasBufferSemantics",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return this->getOperation()->getNumResults() == 0 &&
+ llvm::all_of(getInputs(),
+ [](Value v) { return v.getType().isa<MemRefType>(); });
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
Return whether the op has only RankedTensor input and outputs.
- }], "bool", "hasTensorSemantics">,
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasTensorSemantics",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto isTensorType = [](Value v) {
+ return v.getType().isa<RankedTensorType>();
+ };
+ return llvm::all_of(getInputs(), isTensorType) &&
+ llvm::all_of(this->getOperation()->getResults(), isTensorType);
+ }]
+ >,
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
- StaticInterfaceMethod<[{
+ StaticInterfaceMethod<
+ /*desc=*/[{
Create an operation of the current type with the given location,
operands, and attributes.
}],
- "Operation *", "create",
+ /*retTy=*/"Operation *",
+ /*methodName=*/"create",
(ins "OpBuilder &":$builder, "Location":$loc,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
@@ -192,11 +510,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
attributes);
}]
>,
- InterfaceMethod<[{
+ InterfaceMethod<
+ /*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation.
}],
- "Operation *", "clone",
+ /*retTy=*/"Operation *",
+ /*methodName=*/"clone",
(ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
BlockAndValueMapping map;
unsigned numRegions = $_op.getOperation()->getNumRegions();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 8dda7d0a1445..c4790ca617f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -49,8 +49,8 @@ template <unsigned N> class NOutputs {
};
};
-/// This class provides the API for structured ops that are known to operate on
-/// buffers or tensors. This trait must be used in conjunction with an op
+/// This class provides a verifier for structured ops that are known to operate
+/// on buffers or tensors. This trait must be used in conjunction with an op
/// definition or a trait that provides the methods `getNumInputs` and
/// `getNumOutputs`. Use as a trait as follows:
///
@@ -59,324 +59,18 @@ template <unsigned N> class NOutputs {
template <typename ConcreteType>
class StructuredOpTraits
: public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
-private:
- /// Return the number of inputs, irrespective of their buffer or tensor type.
- /// For internal use only.
- unsigned nInputs() {
- return cast<ConcreteType>(this->getOperation()).getNumInputs();
- }
- /// Return the number of outputs, irrespective of their buffer or tensor type.
- /// For internal use only.
- unsigned nOutputs() {
- return cast<ConcreteType>(this->getOperation()).getNumOutputs();
- }
-
public:
- //==========================================================================//
- // Loop types handling.
- //==========================================================================//
- unsigned getNumParallelLoops() {
- return getNumIterators(
- getParallelIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
- unsigned getNumReductionLoops() {
- return getNumIterators(
- getReductionIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
- unsigned getNumWindowLoops() {
- return getNumIterators(
- getWindowIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
- unsigned getNumLoops() {
- return getNumIterators(
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
-
- bool hasSingleReductionLoop() {
- auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types();
- return iterators.size() == 1 &&
- getNumIterators(getReductionIteratorTypeName(), iterators);
- }
-
- //==========================================================================//
- // Input arguments handling.
- //==========================================================================//
- // The `i^th` input argument is always the `i^th` operand regardless of
- // whether we have tensors or buffers.
- //
- /// Return the `i`-th input value.
- Value getInput(unsigned i) {
- assert(i < nInputs());
- return this->getOperation()->getOperand(i);
- }
- /// Return the index of `value` in the list of inputs if found, llvm::None
- /// otherwise.
- Optional<unsigned> getIndexOfInput(Value value) {
- auto it = llvm::find(getInputs(), value);
- if (it != getInputs().end())
- return it - getInputs().begin();
- return llvm::None;
- }
- /// Return the `i`-th input shaped type, irrespective of buffer or tensor
- /// type.
- ShapedType getInputShapedType(unsigned i) {
- return getInput(i).getType().template cast<ShapedType>();
- }
- /// Return the range over inputs.
- Operation::operand_range getInputs() {
- auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + nInputs()};
- }
- /// Query the subset of input operands that are of ranked tensor type.
- SmallVector<RankedTensorType, 4> getInputTensorTypes() {
- SmallVector<RankedTensorType, 4> res;
- for (Type type : getInputs().getTypes())
- if (auto t = type.template dyn_cast<RankedTensorType>())
- res.push_back(t);
- return res;
- }
-
- //==========================================================================//
- // Output arguments handling.
- //==========================================================================//
- // The `i^th` output argument is an operand (resp. a return value) iff it is
- // a value of buffer type (resp. a return value of tensor type).
-
- /// Return the `i`-th output, asserts that this is a buffer operand and not
- /// a tensor result.
- Value getOutputBuffer(unsigned i) {
- assert(i + this->getOperation()->getNumResults() < nOutputs() &&
- "overflowing output buffer index");
- return this->getOperation()->getOperand(nInputs() + i);
- }
- /// Return the index of `value` in the list of output buffers if found,
- /// llvm::None otherwise.
- Optional<unsigned> getIndexOfOutputBuffer(Value value) {
- auto it = llvm::find(getOutputBuffers(), value);
- if (it != getOutputBuffers().end())
- return it - getOutputBuffers().begin();
- return llvm::None;
- }
- /// Return the `i`-th output buffer type.
- MemRefType getOutputBufferType(unsigned i) {
- return getOutputBuffer(i).getType().template cast<MemRefType>();
- }
- /// Return the `i`-th output shaped type, irrespective of buffer of tensor
- /// type.
- ShapedType getOutputShapedType(unsigned i) {
- return getShapedType(i + nInputs());
- }
- /// Query the subset of results that are of ranked tensor type.
- SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
- SmallVector<RankedTensorType, 4> res;
- for (Type type : this->getOperation()->getResults().getTypes())
- res.push_back(type.template cast<RankedTensorType>());
- return res;
- }
- /// Return the range over outputs.
- Operation::operand_range getOutputBuffers() {
- auto range = this->getOperation()->getOperands();
- return {range.begin() + nInputs(),
- range.begin() + getNumInputsAndOutputBuffers()};
- }
-
- //==========================================================================//
- // Input and Output arguments handling.
- //==========================================================================//
- Value getBuffer(unsigned i) {
- assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
- return this->getOperation()->getOperand(i);
- }
- /// Return the number of inputs and outputs, irrespective of their buffer or
- /// tensor type.
- unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
- /// Return the number of inputs, irrespective of their buffer or tensor type,
- /// and output buffers.
- unsigned getNumInputsAndOutputBuffers() {
- assert(this->getOperation()->getNumResults() <= nOutputs());
- return nInputs() + nOutputs() - this->getOperation()->getNumResults();
- }
- /// Return the range over inputs (irrespective of type) and output buffers.
- Operation::operand_range getInputsAndOutputBuffers() {
- auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
- }
- /// Return the `i`-th shaped type, there are 3 cases:
- /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise
- /// 2. if `i < getNumInputsAndOutputBuffers()` then return the
- /// `getOutputBufferType(i - nInputs())`; otherwise
- /// 3. return the `i - getNumInputsAndOutputBuffers()` result type.
- ShapedType getShapedType(unsigned i) {
- if (i < nInputs())
- return getInputShapedType(i);
- if (i < getNumInputsAndOutputBuffers())
- return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
- return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
- .template cast<ShapedType>();
- }
- /// Return the shaped types for all the inputs and outputs
- SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
- SmallVector<Type, 4> inputOutputTypes(
- this->getOperation()->operand_type_begin(),
- this->getOperation()->operand_type_end());
- inputOutputTypes.append(this->getOperation()->result_type_begin(),
- this->getOperation()->result_type_end());
- return llvm::to_vector<4>(
- llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
- return type.cast<ShapedType>();
- }));
- }
-
- //==========================================================================//
- // Other interface methods.
- //==========================================================================//
-
- // Get or build the indexing_maps ArrayAttr.
- ArrayAttr iterator_types() {
- // Return the attribute if it is present.
- if (auto attr = this->getOperation()->getAttr("iterator_types"))
- return attr.template cast<ArrayAttr>();
-
- // If not, form the attribute using the reference iterator types for the
- // ConcreteType.
- auto maybeReferenceIteratorTypes =
- cast<ConcreteType>(this->getOperation()).referenceIterators();
-
- // If there is no reference, this must be a generic op.
- // TODO: Traits are used to define ops. Split into cpp to avoid cyclic
- // dependency.
- auto name = this->getOperation()->getName().getStringRef();
- if (!maybeReferenceIteratorTypes && name != "generic" &&
- name != "indexed_generic") {
- this->getOperation()->dump();
- llvm_unreachable("Op missing referenceIterators");
- }
-
- // If we have a reference, build the reference attribute and set it in the
- // op before returning.
- auto *ctx = this->getOperation()->getContext();
- auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes,
- [ctx](StringRef str) -> Attribute {
- return StringAttr::get(str, ctx);
- });
- auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx);
- // TODO: Need to memoize this. Can't just store as an attribute atm as it
- // will impact parser, printer and tests.
- // this->getOperation()->setAttr("iterator_types", attr);
- return attr;
- }
-
- // Get or build the indexing_maps ArrayAttr.
- ArrayAttr indexing_maps() {
- // Return the attribute if it is present.
- if (auto attr = this->getOperation()->getAttr("indexing_maps"))
- return attr.template cast<ArrayAttr>();
-
- // If not, form the attribute using the reference indexing map for the
- // ConcreteType.
- auto maybeReferenceIndexingMaps =
- cast<ConcreteType>(this->getOperation()).referenceIndexingMaps();
-
- // If there is no reference, this must be a generic op.
- auto name = this->getOperation()->getName().getStringRef();
- if (!maybeReferenceIndexingMaps && name != "generic" &&
- name != "indexed_generic") {
- this->getOperation()->dump();
- llvm_unreachable("Op missing referenceIndexingMaps");
- }
-
- // If we have a reference, build the reference attribute and set it in the
- // op before returning.
- auto *ctx = this->getOperation()->getContext();
- auto attrRange =
- llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) {
- // 0-D corner case because there is no such thing as a concrete empty
- // map type.
- if (!map)
- map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
- return AffineMapAttr::get(map);
- });
- SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()};
- auto attr = ArrayAttr::get(attrs, ctx);
- // TODO: Need to memoize this. Can't just store as an attribute atm as it
- // will impact parser, printer and tests.
- // this->getOperation()->setAttr("indexing_maps", attr);
- return attr;
- }
-
- SmallVector<AffineMap, 4> getIndexingMaps() {
- return llvm::to_vector<4>(
- llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
- }
-
- AffineMap getIndexingMap(unsigned i) {
- assert(i < getNumInputsAndOutputs());
- return indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
- }
-
- AffineMap getInputIndexingMap(unsigned i) {
- assert(i < nInputs());
- return indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
- }
-
- AffineMap getOutputIndexingMap(unsigned i) {
- assert(i < nOutputs());
- return indexing_maps()
- .getValue()[i + nInputs()]
- .template cast<AffineMapAttr>()
- .getValue();
- }
-
- /// Query whether the op has only buffer inputs and no returns.
- bool hasBufferSemantics() {
- return this->getOperation()->getNumResults() == 0 &&
- llvm::all_of(getInputs(),
- [](Value v) { return v.getType().isa<MemRefType>(); });
- }
-
- /// Query whether the op has only tensor inputs and outputs.
- bool hasTensorSemantics() {
- auto isTensorType = [](Value v) {
- return v.getType().isa<RankedTensorType>();
- };
- return llvm::all_of(getInputs(), isTensorType) &&
- llvm::all_of(this->getOperation()->getResults(), isTensorType);
- }
-
- //==========================================================================//
- // Other static interface methods.
- //==========================================================================//
static LogicalResult verifyTrait(Operation *op) {
+ ConcreteType concreteOp = cast<ConcreteType>(op);
+ if (op->getNumResults() > concreteOp.getNumOutputs()) // Before nOperands.
+ return op->emitError("unexpected #results > #outputs");
auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
return failure();
return success();
}
};
-/// This class provides the API for named Linalg StructuredOps.
-template <typename ConcreteType>
-class NamedStructuredOpTraits
- : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
-public:
- static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
- TypeRange outputTypes);
-
- static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
- TypeRange outputTypes);
-};
-
} // namespace linalg
} // namespace OpTrait
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 77eb64489477..7071cd385f77 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -260,13 +260,14 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
return failure();
- auto attr = op.template getAttrOfType<IntegerAttr>("symbol_source");
- int64_t targetRank = 0;
- if (attr) {
- unsigned index = attr.getInt();
+ auto symbolSourceAttr =
+ op.template getAttrOfType<IntegerAttr>("symbol_source");
+ int64_t expectedNumSymbols = 0;
+ if (symbolSourceAttr) {
+ unsigned index = symbolSourceAttr.getInt();
if (index >= op.getNumOperands())
return op.emitOpError("symbol_source index out of range");
- targetRank = op.getShapedType(index).getRank();
+ expectedNumSymbols = op.getShapedType(index).getRank();
}
SmallVector<AffineMap, 4> indexingMaps;
@@ -278,9 +279,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
: op.getOutputShapedType(idx - nInputViews);
- if (m.getNumSymbols() != targetRank)
+ if (m.getNumSymbols() != expectedNumSymbols)
return op.emitOpError("expected the number of symbols in indexing_map #")
- << idx << " to match target rank";
+ << idx << " to match rank of operand `symbol_source`";
if (m.getNumDims() != nLoops)
return op.emitOpError("expected indexing_map #")
@@ -1246,15 +1247,9 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body);
- auto indexingMaps = builder.getAffineMapArrayAttr(
- NamedStructuredOpType::referenceIndexingMaps(operandTypes,
- tensorResultTypes));
- result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
+ // indexing_maps is an auto-generated method.
- auto iterators =
- builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
- operandTypes, tensorResultTypes));
- result.addAttribute(getIteratorTypesAttrName(), iterators);
+ // iterator_types is an auto-generated method.
}
template <typename NamedStructuredOpType>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c631c47099b0..3774aed7ad1f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -113,7 +113,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
// -----
func @generic_symbol_in_map(%arg0: memref<i32>) {
- // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}}
+ // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}}
linalg.generic {
args_in = 0,
args_out = 1,
@@ -514,3 +514,20 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
return
}
+
+// -----
+
+func @generic(%arg0: tensor<?x?xi4>) {
+ // expected-error @+1 {{unexpected #results > #outputs}}
+ linalg.generic {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = [ affine_map<(i) -> (i)> ],
+ iterator_types = ["parallel"]
+ } %arg0 {
+ ^bb(%0: i4) :
+ %1 = std.addi %0, %0: i4
+ linalg.yield %1, %1: i4, i4
+ } : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
+ return
+}
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index d796d1917c03..aad983eb85d2 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -4,16 +4,15 @@
// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
-// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
-// IMPL-LABEL: SmallVector<StringRef, 8> Test1Op::referenceIterators
+// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
-// IMPL: SmallVector<AffineMap, 8> Test1Op::referenceIndexingMaps
+// IMPL: ArrayAttr Test1Op::indexing_maps() {
// IMPL: AffineMap::get(2, 0, {d0, d1}, context),
// IMPL-NEXT: AffineMap::get(2, 0, {d1}, context),
-// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
+// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) });
//
// IMPL: void Test1Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@@ -29,16 +28,15 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
-// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
-// IMPL-LABEL: SmallVector<StringRef, 8> Test2Op::referenceIterators
+// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
-// IMPL: SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps
+// IMPL: ArrayAttr Test2Op::indexing_maps() {
// IMPL: AffineMap::get(3, 0, {d0, d2}, context),
// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context),
-// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
+// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) });
//
// IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@@ -54,16 +52,15 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
-// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
-// IMPL-LABEL: SmallVector<StringRef, 8> Test3Op::referenceIterators
+// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
-// IMPL: SmallVector<AffineMap, 8> Test3Op::referenceIndexingMaps
+// IMPL: ArrayAttr Test3Op::indexing_maps() {
// IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context),
// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context),
-// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
+// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) });
//
// IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 92efef67e8f4..59d655684f48 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -974,19 +974,19 @@ class TCParser {
/// Parse and print the information for a TC def.
/// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
/// When `gen-impl` is used, this prints the C++ implementation for the extra
- /// methods defined in ODS (referenceIterators, referenceIndexingMaps and
- /// regionBuilder).
+ /// methods defined in ODS (`iterator_types`, `indexing_maps` and
+ /// `regionBuilder`).
LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
StringRef linalgOpName);
- /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
+ /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state);
- /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
+ /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state);
@@ -1446,7 +1446,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [
NInputs<{2}>,
NOutputs<{3}>,
- NamedStructuredOpTraits,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins Variadic<LinalgOperand>:$views);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
@@ -1465,16 +1464,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
return ::parseNamedStructuredOp<{0}>(parser, result);
}];
let extraClassDeclaration = [{{
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
- static SmallVector<StringRef, 8> referenceIterators(
- TypeRange inputTypes, TypeRange outputTypes);
-
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
- static SmallVector<AffineMap, 8> referenceIndexingMaps(
- TypeRange inputTypes, TypeRange outputTypes);
-
+ ArrayAttr iterator_types();
+ ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
-
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
@@ -1492,20 +1484,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
}
-/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
+/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void TCParser::printReferenceIterators(llvm::raw_ostream &os,
StringRef cppOpName,
ComprehensionParsingState &state) {
const char *referenceReferenceIteratorsFmt =
R"FMT(
- // This is temporary until we transition out of manually specified ops
- // that should be auto-generated with linalg-ods-gen.
- llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {{
- llvm_unreachable("Unexpected missing `iterator_types` attribute.");
- }
- SmallVector<StringRef, 8> {0}::referenceIterators(
- TypeRange inputTypes, TypeRange outputTypes) {
- return SmallVector<StringRef, 8>{{ {1} };
+ ArrayAttr {0}::iterator_types() {
+ return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
})FMT";
std::string iteratorsStr;
@@ -1542,16 +1528,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
R"FMT(
// This is temporary until we transition out of manually specified ops that
// should be auto-generated with linalg-ods-gen.
- llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {{
- llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
- }
- SmallVector<AffineMap, 8> {0}::referenceIndexingMaps(
- TypeRange inputTypes, TypeRange outputTypes) {
- assert(!inputTypes.empty() && "At least one input expected");
- MLIRContext *context = (*inputTypes.begin()).getContext();
+ ArrayAttr {0}::indexing_maps() {
+ MLIRContext *context = getContext();
AffineExpr {1};
bindDims(context, {1});
- return SmallVector<AffineMap, 8>{{ {2} };
+ return Builder(context).getAffineMapArrayAttr({ {2} });
})FMT";
// 2. Print a comma-separated list of identifiers for the AffineExpr in
More information about the Mlir-commits mailing list