[Mlir-commits] [mlir] Extend `TilingInterface` to allow more flexible tiling (PR #95422)
Srinath Avadhanula
llvmlistbot at llvm.org
Fri Jun 14 04:49:24 PDT 2024
https://github.com/srinathava updated https://github.com/llvm/llvm-project/pull/95422
>From 18ddecd8ab7738d44448beb3aa81b9db4f4cd6f2 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula <srinath.avadhanula at getcruise.com>
Date: Thu, 13 Jun 2024 08:29:31 -0700
Subject: [PATCH 1/2] initial commit
---
.../SCF/Transforms/TileUsingInterface.h | 2 +
.../include/mlir/Interfaces/TilingInterface.h | 4 ++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 8 +++-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 11 +++++-
.../SCF/Transforms/TileUsingInterface.cpp | 37 ++++++++++---------
5 files changed, 41 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..fecd33193eb0d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,7 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
+ SmallVector<Operation *> extractSliceOps;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
+ SmallVector<Operation *> extractSliceOps;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..e5ed016d53fc1 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -28,9 +28,13 @@ namespace mlir {
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
+/// - `extractSliceOps` contains the `tensor.extract_slice` ops whose results
+/// feed the `tiledOps` (possibly none). Usually these define operands of the
+/// `tiledOps`, but they may also be nested in regions the `tiledOps` own.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
+ SmallVector<Operation *> extractSliceOps;
};
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..5198e0bceaa6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ SmallVector<Operation *> sliceOps;
+ for (Value operand : tiledOperands)
+ if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+ sliceOps.push_back(sliceOp);
+
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}
LogicalResult SoftmaxOp::getResultTilePosition(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..f25ccc38ba0a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ SmallVector<Operation *> sliceOps;
+ for (Value operand : tiledOperands)
+ if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+ sliceOps.push_back(sliceOp);
+
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -247,7 +253,8 @@ struct LinalgOpTilingInterface
return TilingResult{
tilingResult->tiledOps,
- SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+ tilingResult->extractSliceOps};
}
/// Method to generate the tiled implementation of an operation from the tile
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..fb3ec2a5fa0a8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
- TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+ TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+ /*extractSliceOps=*/{}};
return success();
}
@@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// op.
if (loops.empty()) {
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues};
+ tilingResult->tiledValues,
+ tilingResult->extractSliceOps};
}
SmallVector<Value> replacements = llvm::map_to_vector(
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+ return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ tilingResult->extractSliceOps};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
.set(origDestinationTensors[resultNumber]);
}
- return scf::SCFFuseProducerOfSliceResult{fusableProducer,
- tileAndFuseResult->tiledValues[0],
- tileAndFuseResult->tiledOps};
+ return scf::SCFFuseProducerOfSliceResult{
+ fusableProducer, tileAndFuseResult->tiledValues[0],
+ tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
.getDefiningOp<DestinationStyleOpInterface>()) {
rewriter.setInsertionPoint(tiledDestStyleOp);
Value newRegionArg = newRegionIterArgs.back();
- auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
- rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
- tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
- });
+ auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
+ .getDefiningOp<tensor::ExtractSliceOp>();
+ if (origSlice) {
+ origSlice.getSourceMutable().set(newRegionArg);
+ }
}
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
@@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp,
+ auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
std::deque<tensor::ExtractSliceOp> &candidates) {
- for (Value operand : fusedOp->getOperands())
- if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
- candidates.push_back(sliceOp);
+ for (auto *op : newSliceOps)
+ candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
};
std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tilingResult->extractSliceOps, candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
@@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
- addCandidateSlices(tiledAndFusedOp, candidates);
+ addCandidateSlices(fusedResult->extractSliceOps, candidates);
}
}
>From af5f7a5b21af2137da0598b41b7c8c032b89a264 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula <srinath.avadhanula at getcruise.com>
Date: Fri, 14 Jun 2024 04:48:24 -0700
Subject: [PATCH 2/2] also add extractSliceOps to TensorTilingInterfaceImpl
---
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 36 +++++++++++++++----
1 file changed, 29 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9b2a97eb2b006..33db5a5f043f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -99,6 +99,16 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
+/// Collects the `tensor.extract_slice` ops that directly define operands of
+/// `op`, in operand order; a slice feeding several operands is listed per use.
+static SmallVector<Operation *> sliceOperandsOf(Operation *op) {
+  SmallVector<Operation *> sliceOps;
+  for (auto operand : op->getOperands())
+    if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+      sliceOps.push_back(sliceOp);
+  return sliceOps;
+}
+
struct PackOpTiling
: public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
@@ -192,7 +202,8 @@ struct PackOpTiling
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
return TilingResult{{tiledPackOp},
- SmallVector<Value>(tiledPackOp->getResults())};
+ SmallVector<Value>(tiledPackOp->getResults()),
+ sliceOperandsOf(tiledPackOp)};
}
LogicalResult
@@ -440,12 +451,16 @@ struct UnPackOpTiling
if (isPerfectTilingCase)
return TilingResult{{tiledUnpackOp},
- SmallVector<Value>(tiledUnpackOp->getResults())};
+ SmallVector<Value>(tiledUnpackOp->getResults()),
+ sliceOperandsOf(tiledUnpackOp)};
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
- return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+
+ return TilingResult{{tiledUnpackOp},
+ {extractSlice.getResult()},
+ sliceOperandsOf(tiledUnpackOp)};
}
LogicalResult
@@ -567,7 +582,8 @@ struct UnPackOpTiling
tiledOperands, op->getAttrs());
return TilingResult{{tiledUnPackOp},
- SmallVector<Value>(tiledUnPackOp->getResults())};
+ SmallVector<Value>(tiledUnPackOp->getResults()),
+ sliceOperandsOf(tiledUnPackOp)};
}
};
@@ -756,7 +772,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// the original data source x is not used.
if (hasZeroLen) {
Operation *generateOp = createGenerateOp();
- return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+ return TilingResult{{generateOp},
+ {castResult(generateOp->getResult(0))},
+ /*extractSliceOps=*/{}};
}
// If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -776,11 +794,15 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
elseOp = createPadOfExtractSlice();
b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
});
- return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+ return TilingResult{{elseOp},
+ SmallVector<Value>(result->getResults()),
+ sliceOperandsOf(elseOp)};
}
Operation *newPadOp = createPadOfExtractSlice();
- return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+ return TilingResult{{newPadOp},
+ {castResult(newPadOp->getResult(0))},
+ sliceOperandsOf(newPadOp)};
}
void mlir::tensor::registerTilingInterfaceExternalModels(
More information about the Mlir-commits
mailing list