[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 09:07:05 PDT 2024
================
@@ -1754,79 +1723,105 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (failed(tileAndFuseResult)) {
return failure();
}
- rewriter.replaceAllUsesWith(
- tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
-
- // 8 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
- // 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
+ auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+ rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
- // 10. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ // 7 - Reconstruct [nested] loop with new inits.
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ // 9. Set inner insertPoint right before tiled consumer op.
----------------
MaheshRavishankar wrote:
Nit: this is step 8, not 9 (it directly follows step 7).
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits mailing list