[Mlir-commits] [mlir] 2ea6e13 - [mlir] Add an optional distributionTypes attribute to TiledLoopOp.
Alexander Belyaev
llvmlistbot at llvm.org
Tue May 25 11:05:16 PDT 2021
Author: Alexander Belyaev
Date: 2021-05-25T20:04:41+02:00
New Revision: 2ea6e13bf8189efb09cec89184b21f1db3de0d1c
URL: https://github.com/llvm/llvm-project/commit/2ea6e13bf8189efb09cec89184b21f1db3de0d1c
DIFF: https://github.com/llvm/llvm-project/commit/2ea6e13bf8189efb09cec89184b21f1db3de0d1c.diff
LOG: [mlir] Add an optional distributionTypes attribute to TiledLoopOp.
Differential Revision: https://reviews.llvm.org/D103104
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index c78624b70b7b5..6c46c0abfcb46 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -521,7 +521,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
%0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
outs(%out : tensor<24x64xi8>)
- iterators("parallel") {
+ iterators("parallel")
+ distribution("block_x") {
%lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
: tensor<24x64xi8> to tensor<?x?xi8>
%rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
@@ -551,7 +552,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
outs(%out : memref<24x64xi8>)
- iterators("parallel") {
+ iterators("parallel")
+ distribution("block_x") {
%lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
@@ -570,11 +572,18 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
Variadic<Index>:$step,
Variadic<LinalgOperand>:$inputs,
Variadic<LinalgOperand>:$outputs,
- ArrayAttr:$iterator_types);
+ ArrayAttr:$iterator_types,
+ OptionalAttr<ArrayAttr>:$distribution_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region);
let builders = [
+ OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
+ "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
+ "ArrayAttr":$iteratorTypes, "Optional<ArrayAttr>":$distributionTypes,
+ CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
+ "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
+ "nullptr">:$bodyBuilderFn)>,
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayAttr":$iteratorTypes,
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index d6ccea1df7a4e..14f8b29f66892 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -50,6 +50,12 @@ constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
/// op's iterators.
constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
+/// Attribute name for the optional StrArrayAttr which encodes the distribution
+/// types (one entry per loop) of `linalg.tiled_loop`.
+constexpr StringRef getDistributionTypesAttrName() {
+ return "distribution_types";
+}
+
/// Attribute name for the StringAttr which encodes an optional documentation
/// string of the structured op.
constexpr StringRef getDocAttrName() { return "doc"; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f3de81801e672..034bf188d175e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2075,6 +2075,18 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
function_ref<void(OpBuilder &, Location, ValueRange,
ValueRange, ValueRange)>
bodyBuilderFn) {
+ build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
+ iteratorTypes, llvm::None, bodyBuilderFn);
+}
+
+void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
+ ValueRange lowerBounds, ValueRange upperBounds,
+ ValueRange steps, ValueRange inputs, ValueRange outputs,
+ ArrayAttr iteratorTypes,
+ Optional<ArrayAttr> distributionTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange,
+ ValueRange, ValueRange)>
+ bodyBuilderFn) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
result.addOperands(steps);
@@ -2089,6 +2101,10 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
static_cast<int32_t>(outputs.size())}));
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
+ if (distributionTypes.hasValue())
+ result.addAttribute(getDistributionTypesAttrName(),
+ distributionTypes.getValue());
+
// Add output types for `RankedTensorType` output arguments.
for (Value output : outputs) {
Type outputType = output.getType();
@@ -2143,14 +2159,17 @@ static void print(OpAsmPrinter &p, TiledLoopOp op) {
if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() !=
getParallelIteratorTypeName();
- })) {
+ }))
p << " iterators" << op.iterator_types() << "";
- }
+
+ if (op.distribution_types().hasValue())
+ p << " distribution" << op.distribution_types().getValue() << "";
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
- getIteratorTypesAttrName()});
+ getIteratorTypesAttrName(),
+ getDistributionTypesAttrName()});
}
static ParseResult parseTiledLoopOp(OpAsmParser &parser,
@@ -2219,26 +2238,38 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
}
// Parse attributes.
- SmallVector<Attribute, 4> iterTypes;
- if (succeeded(parser.parseOptionalKeyword("iterators"))) {
- StringAttr iterType;
+ SmallVector<Attribute, 4> iterTypes, distributionTypes;
+ auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
+ if (succeeded(parser.parseOptionalKeyword(keyword))) {
+ StringAttr attr;
- if (parser.parseLSquare() || parser.parseAttribute(iterType))
- return failure();
- iterTypes.push_back(iterType);
- for (int i = 1, e = ivs.size(); i < e; ++i) {
- if (parser.parseComma() || parser.parseAttribute(iterType))
+ if (parser.parseLSquare() || parser.parseAttribute(attr))
+ return failure();
+ attrs->push_back(attr);
+ for (int i = 1, e = ivs.size(); i < e; ++i) {
+ if (parser.parseComma() || parser.parseAttribute(attr))
+ return failure();
+ attrs->push_back(attr);
+ }
+ if (parser.parseRSquare())
return failure();
- iterTypes.push_back(iterType);
}
- if (parser.parseRSquare())
- return failure();
- } else {
+ return success();
+ };
+ if (failed(parseAttr("iterators", &iterTypes)) ||
+ failed(parseAttr("distribution", &distributionTypes)))
+ return failure();
+
+ // Set all loop iterator types to "parallel" if they are not printed in IR.
+ if (iterTypes.empty()) {
auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
}
result.addAttribute(getIteratorTypesAttrName(),
builder.getArrayAttr(iterTypes));
+ if (!distributionTypes.empty())
+ result.addAttribute(getDistributionTypesAttrName(),
+ builder.getArrayAttr(distributionTypes));
result.addAttribute(
TiledLoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
@@ -2352,7 +2383,8 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
Location loc = tiledLoop.getLoc();
auto newTiledLoop = rewriter.create<TiledLoopOp>(
loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
- newInputs, tiledLoop.outputs(), tiledLoop.iterator_types());
+ newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
+ tiledLoop.distribution_types());
// Clone the region.
BlockAndValueMapping bvm;
@@ -2441,7 +2473,8 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
Location loc = tiledLoop.getLoc();
auto newTiledLoop = rewriter.create<TiledLoopOp>(
loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
- tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+ tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
+ tiledLoop.distribution_types());
// Clone the region.
BlockAndValueMapping bvm;
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1435f7b6bdb1b..211e97855c662 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -827,7 +827,8 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
%i2d_ = %input_2d: tensor<16x32xf32>,
%i1d_ = %input_1d: tensor<24xf32>)
outs(%o_ = %output: tensor<24xf32>)
- iterators["reduction", "parallel", "reduction"] {
+ iterators["reduction", "parallel", "reduction"]
+ distribution["block_x", "block_y", "none"] {
%sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
: tensor<16x24x32xf32> to tensor<2x4x8xf32>
%sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1]
More information about the Mlir-commits mailing list