[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)
donald chen
llvmlistbot at llvm.org
Wed Apr 17 03:04:16 PDT 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/85528
>From b7a0a8575e4760fe21e9f4f3f7fca36da4339b6e Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH] [mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface.
- Add the interface method 'getIterationDomainTileFromOperandTile' to the tiling
interface, which gets the iteration domain position from the operand position.
- Add the interface method 'getTiledImplementationFromOperandTile' to the tiling
interface, which generates the tiled implementation according to the operand position.
---
.../mlir/Interfaces/TilingInterface.td | 55 +++++++++++
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 95 +++++++++++++++----
2 files changed, 129 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..9fff3465ad9e72 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the offsets and sizes of the iteration domain tile
+ that corresponds to the given tile of an operand.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand tile position.
+
+ Generates the IR that computes the tiled implementation of an
+ operation from operand tile. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the tile of the
+ operand required. This method enables consumer fusion by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..094ae81d2a01fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,58 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+ void getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &mappedOffsets,
+ SmallVector<OpFoldResult> &mappedSizes) const { // Maps a tile given in the results of `indexingMap` to an iteration-space tile.
+ auto numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+ mappedOffsets.resize(numLoops); // outputs get one entry per loop of the op
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) { // not a permutation: first seed every loop from the op's iteration domain
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &range : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[range.index()] = range.value().offset;
+ mappedSizes[range.index()] = range.value().size;
+ }
+ }
+ for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) { // then overwrite the loops the map refers to
+ unsigned dimPosition =
+ cast<AffineDimExpr>(resultExpr.value()).getPosition(); // each map result is treated as a plain loop dimension
+ mappedOffsets[dimPosition] = offsets[resultExpr.index()]; // i-th tile entry goes to the loop named by the i-th map result
+ mappedSizes[dimPosition] = sizes[resultExpr.index()];
+ }
+ }
+
+ // Return the offsets and sizes of the iteration domain tile that
+ // corresponds to the given tile of operand `operandNumber`.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &iterDomainOffsets,
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the operand).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return op->emitOpError(
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
// Return the details of the output tile generated by the tiled
// implementation.
LogicalResult
@@ -160,6 +212,20 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { // Tiles `op` starting from a tile of operand `operandNumber`.
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) { // step 1: operand tile -> iteration domain tile
+ return op->emitOpError(
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, // step 2: tile the op on that iteration domain tile
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +243,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
More information about the Mlir-commits
mailing list