[Mlir-commits] [mlir] 8542794 - Relax requirements for TileOp.
Johannes Reifferscheid
llvmlistbot at llvm.org
Tue Jan 24 06:17:01 PST 2023
Author: Johannes Reifferscheid
Date: 2023-01-24T15:16:55+01:00
New Revision: 85427941e7e1790aca1aa6c7c3876f9b7571e30d
URL: https://github.com/llvm/llvm-project/commit/85427941e7e1790aca1aa6c7c3876f9b7571e30d
DIFF: https://github.com/llvm/llvm-project/commit/85427941e7e1790aca1aa6c7c3876f9b7571e30d.diff
LOG: Relax requirements for TileOp.
The op doesn't need to be a LinalgOp; implementing TilingInterface and
DestinationStyleOpInterface is sufficient.
Reviewed By: nicolasvasilache, ftynse
Differential Revision: https://reviews.llvm.org/D142460
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8d06ef6cca0c5..4e0be5aa8dbd4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1902,20 +1902,21 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<Operation *> tiled;
SmallVector<SmallVector<Operation *, 4>, 4> loops;
loops.resize(getLoops().size());
- for (auto &en : llvm::enumerate(targets)) {
- auto linalgOp = dyn_cast<LinalgOp>(en.value());
- if (!linalgOp) {
- DiagnosedSilenceableFailure diag = emitSilenceableError()
- << "only linalg ops are supported";
- diag.attachNote(en.value()->getLoc()) << "target op";
+ for (auto &[i, op] : llvm::enumerate(targets)) {
+ auto tilingInterface = dyn_cast<TilingInterface>(op);
+ auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
+ if (!tilingInterface || !dpsInterface) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "only ops implementing TilingInterface and "
+ "DestinationStyleOpInterface are supported";
+ diag.attachNote(op->getLoc()) << "target op";
return diag;
}
scf::SCFTilingOptions tilingOptions;
- unsigned index = en.index();
if (!tileSizes.empty()) {
- tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
- Operation *) {
+ tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
+ Operation *) {
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
@@ -1942,18 +1943,16 @@ transform::TileOp::apply(TransformResults &transformResults,
}
tilingOptions.setInterchange(getInterchange());
- TrivialPatternRewriter rewriter(linalgOp.getContext());
- FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
- rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- tilingOptions);
+ TrivialPatternRewriter rewriter(op->getContext());
+ FailureOr<scf::SCFTilingResult> maybeTilingResult =
+ tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();
- if (linalgOp.hasBufferSemantics())
- rewriter.eraseOp(linalgOp);
+ if (dpsInterface.hasBufferSemantics())
+ rewriter.eraseOp(op);
else
- rewriter.replaceOp(linalgOp,
- maybeTilingResult->loops.front()->getResults());
+ rewriter.replaceOp(op, maybeTilingResult->loops.front()->getResults());
tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
More information about the Mlir-commits
mailing list