[Mlir-commits] [mlir] 52ffc72 - [mlir][tiling] Relax tiling to accept generating multiple operations.
Hanhan Wang
llvmlistbot at llvm.org
Fri Nov 4 13:59:43 PDT 2022
Author: Hanhan Wang
Date: 2022-11-04T13:59:24-07:00
New Revision: 52ffc728181bc2d3c889f7f80c252c3433b9e7b6
URL: https://github.com/llvm/llvm-project/commit/52ffc728181bc2d3c889f7f80c252c3433b9e7b6
DIFF: https://github.com/llvm/llvm-project/commit/52ffc728181bc2d3c889f7f80c252c3433b9e7b6.diff
LOG: [mlir][tiling] Relax tiling to accept generating multiple operations.
Some operations need to generate multiple operations when implementing
the tiling interface; IREE has a concrete example of this. See
https://github.com/iree-org/iree/pull/10905 for more details.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D137300
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9fa4114c77b11..151993cc3d9a4 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -62,8 +62,10 @@ struct SCFTilingOptions {
/// Transformation information returned after tiling.
struct SCFTilingResult {
- /// The tiled operation generated.
- Operation *tiledOp;
+ /// Tiled operations that are generated during tiling. The order does not
+ /// matter except the last op. The replacements are expected to be the results
+ /// of the last op.
+ SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a35dd14483963..6b8ca9125c82d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -931,7 +931,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- results.push_back(maybeTilingResult->tiledOp);
+ results.append(maybeTilingResult->tiledOps);
return DiagnosedSilenceableFailure(success());
}
@@ -1251,7 +1251,7 @@ transform::TileOp::apply(TransformResults &transformResults,
rewriter.replaceOp(linalgOp,
maybeTilingResult->loops.front()->getResults());
- tiled.push_back(maybeTilingResult->tiledOp);
+ tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
loops[en2.index()].push_back(en2.value());
}
@@ -1609,7 +1609,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
- tiled.push_back(tilingResult->tiledOp);
+ tiled.append(tilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(tilingResult->loops))
loops[en2.index()].push_back(en2.value());
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0c86bd4d1262a..6e59bdb09b12d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -360,11 +360,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tilingResult.loops.back().getBody()->getTerminator());
SmallVector<Operation *> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
- if (tiledImplementation.size() != 1) {
- return rewriter.notifyMatchFailure(
- op, "expected tiled implementation to return a single op");
- }
- tilingResult.tiledOp = tiledImplementation[0];
+ tilingResult.tiledOps.append(tiledImplementation);
if (op->getNumResults() == 0) {
// nothing more to do.
return tilingResult;
@@ -396,13 +392,13 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
- rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
+ rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
resultOffsetsList, resultSizesList, tilingResult.loops);
if (failed(replacementOr))
return rewriter.notifyMatchFailure(op, "failed to yield replacement");
if (auto dstOp =
- dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
+ dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
auto innerMostLoop = tilingResult.loops.back();
SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
assert(destinationTensors.size() ==
@@ -554,13 +550,14 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
- tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
+ for (auto tiledOp : tilingResult->tiledOps)
+ tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
for (const auto &result : llvm::enumerate(
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
tileAndFuseResult.replacements[std::get<0>(result.value())] =
std::get<1>(result.value());
- yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+ yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
result.index())] = result.index();
}
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 31e3c1a529a7c..1644179c427c3 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -193,7 +193,8 @@ struct TestTileUsingSCFForOp
rewriter.eraseOp(op);
}
- filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+ for (auto tiledOp : tilingResult->tiledOps)
+ filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
return success();
}
More information about the Mlir-commits
mailing list