[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #108318)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 11 20:00:36 PDT 2024
================
@@ -1754,79 +1725,108 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (failed(tileAndFuseResult)) {
return failure();
}
- rewriter.replaceAllUsesWith(
- tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
-
- // 8 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
- // 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
+ auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+ rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
- // 10. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ // 7. Reconstruct [nested] loop with new inits.
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ // 8. Set inner insertPoint right before tiled consumer op.
+ innerRewriter.setInsertionPoint(tiledConsumerOp);
- // 11. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice/parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+ // 9. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
+ candidateSliceOp, "containingOp's result yield with stride");
}
- }
- auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
- auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
- if (isInsertSliceOp) {
- auto newForOp = cast<scf::ForOp>(newLoopOp);
- fixTerminatorSCFYield(
- rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
- newForOp.getBody()->getArguments().drop_front(1 + initSize));
- } else {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
- fixTerminatorSCFInParallel(
- rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
- arrayRefOffsets, arrayRefSizes,
- newForallOp.getBody()->getArguments().drop_front(rank + initSize));
- }
+ // 10. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get iter domain position from input position");
+ }
- // 12. Replace the result of scf loop and consumer op with new loop's results.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
+ // 11. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(
+ totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+ if (failed(tiledConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 12. Create `extract_slice` for `iter_args` for DPS operation if
+ // necessary.
+ if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+ tiledConsumerOp.getOperation())) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
+ auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, newRegionArg, resultOffsets[index], resultSizes[index],
+ SmallVector<OpFoldResult>(resultOffsets[index].size(),
+ rewriter.getIndexAttr(1)));
+ // Make C++ 17 happy, otherwise it will throw error `captured structured
+ // bindings are a C++20 extension`.
----------------
Yun-Fly wrote:
Ok, changed.
https://github.com/llvm/llvm-project/pull/108318
More information about the Mlir-commits mailing list