[Mlir-commits] [mlir] 2837991 - [mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Mar 18 08:11:21 PDT 2021
Author: Alexander Belyaev
Date: 2021-03-18T16:11:03+01:00
New Revision: 283799157e504597fc3034cc5fa02faa4e11fa58
URL: https://github.com/llvm/llvm-project/commit/283799157e504597fc3034cc5fa02faa4e11fa58
DIFF: https://github.com/llvm/llvm-project/commit/283799157e504597fc3034cc5fa02faa4e11fa58.diff
LOG: [mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.
Also use `ArrayAttr` to pass iterator types to the TiledLoopOp builder.
Differential Revision: https://reviews.llvm.org/D98871
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index a706d67d2988..5a906ff2dafd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -15,6 +15,8 @@
include "mlir/IR/OpBase.td"
+def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
+
def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 63bee92ded7c..d54efbe37a57 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let summary = "Linalg tiled loop operation";
let description = [{
This is a loop-like operation with additional properties. The arguments
- also include the input and the output tensors and the attributes to specify
- the iterator types. The body region of the loop contains `subtensor`
- operations applied to every tensor argument of TiledLoopOp.
+ also include the input and the output tensors or memrefs and the attributes
+ to specify the iterator types.
+
+ Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
+ to "parallel" type, when it is absent from the custom format.
+
+ Tensor-based version:
+
+ The body region of the loop contains `subtensor` operations applied to
+ every tensor argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
`linalg.yield` with the operands resulting from `subtensor_insert`
operations.
- Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
- to "parallel" type, when it is absent from the custom format.
-
Example:
```mlir
- linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+ %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
outs(%out : tensor<24x64xi8>)
iterators("parallel") {
@@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
linalg.yield %result : tensor<24x64xi8>
}
```
+
+ MemRef-based version:
+
+ The body region of the loop contains `subview` operations applied to
+ every memref argument of TiledLoopOp.
+
+ The body region must contain exactly one block that terminates with
+ `linalg.yield` with no operands.
+
+ Example:
+
+ ```mlir
+ linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+ ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
+ outs(%out : memref<24x64xi8>)
+ iterators("parallel") {
+ %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
+ : memref<24x64xi8> to memref<?x?xi8>
+ %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
+ : memref<24x64xi8> to memref<?x?xi8>
+ %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
+ : memref<24x64xi8> to memref<?x?xi8>
+
+ linalg.generic ...
+ linalg.yield
+ }
+ ```
}];
let arguments = (ins Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
- Variadic<AnyRankedTensor>:$inputs,
- Variadic<AnyRankedTensor>:$outputs,
+ Variadic<LinalgOperand>:$inputs,
+ Variadic<LinalgOperand>:$outputs,
ArrayAttr:$iterator_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let builders = [
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
- "ArrayRef<StringRef>":$iteratorTypes,
+ "ArrayAttr":$iteratorTypes,
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
"nullptr">:$bodyBuilderFn)>,
];
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f87a1eaeac8f..69aa7659b81c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//
-def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
-
class LinalgOperandOfRank<int rank>: Type<
And<[
LinalgOperand.predicate,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3b268d703a74..13cca7f19ee7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
void TiledLoopOp::build(
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
- ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
+ ValueRange outputs, ArrayAttr iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
@@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
- result.addAttribute(getIteratorTypesAttrName(),
- builder.getStrArrayAttr(iteratorTypes));
- result.addTypes(outputs.getTypes());
+ result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ for (Value output : outputs) {
+ Type outputType = output.getType();
+ if (outputType.isa<RankedTensorType>())
+ result.addTypes(outputType);
+ }
OpBuilder::InsertionGuard guard(builder);
unsigned numIVs = steps.size();
@@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
if (bodyBuilderFn) {
builder.setInsertionPointToStart(bodyBlock);
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
+ TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
- TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
static void print(OpAsmPrinter &p, TiledLoopOp op) {
More information about the Mlir-commits
mailing list