[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
Abhishek Varma
llvmlistbot at llvm.org
Mon Jun 3 01:57:00 PDT 2024
================
@@ -1267,25 +1260,198 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
/// After fusing consumer into scf.forall we want to yield each of the resulting
/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
+///
+/// For each index i, a tensor.parallel_insert_slice of `tiledResults[i]` into
+/// `bbArgs[i]` at offsets `resultOffsets[i]` with sizes `resultSizes[i]` and
+/// unit strides is created at the start of the terminator. `tiledResults`,
+/// `bbArgs`, `resultOffsets` and `resultSizes` must all have the same length.
+/// The offset/size vectors are only read, hence they are taken as ArrayRef.
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tiledResults,
+    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+    ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+    ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
-Location firstYieldOpLoc =
-(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+  // Reuse the location of the first existing yielding op; fall back to the
+  // loop's own location instead of dereferencing an empty range.
+  auto yieldingOps = newTerminatorOp.getYieldingOps();
+  Location firstYieldOpLoc = yieldingOps.empty()
+                                 ? newForallOp.getLoc()
+                                 : yieldingOps.begin()->getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
}
}
+// If the outermost loop of the nested loop structure is scf.forall, additional
+// tensor.extract_slice ops have to be created for its newly appended
+// `shared_outs`, so that the inner loops are initialized with the
+// corresponding local slices. E.g.
+//
+// scf.forall shared_outs(%o1=..., %o2=...) {
+//   %local_o1 = extract_slice %o1
+//   // fix newly appended `shared_out` %o2
+//   %local_o2 = extract_slice %o2
+//   scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+//     ...
+//   }
+//   ...
+// }
+//
+// One extract_slice (unit strides, location of `outerLoop`) is created right
+// before `innerLoop` for each of the last `newInitSize` block arguments of
+// `outerLoop`'s body, using the matching entries of `resultOffsets` /
+// `resultSizes`; the number of offset/size entries must equal `newInitSize`.
+// The created ops are returned through `newExtractOps`, whose previous content
+// is discarded. This helper does not rewire `innerLoop`'s init args itself.
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+                      LoopLikeOpInterface innerLoop,
+                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+                      unsigned newInitSize,
+                      SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+  rewriter.setInsertionPoint(innerLoop);
+  Location loc = outerLoop.getLoc();
+  // The newly appended shared_outs are the trailing block arguments.
+  ArrayRef<BlockArgument> newBbArgs =
+      outerLoop.getBody()->getArguments().take_back(newInitSize);
+
+  newExtractOps.clear();
+  newExtractOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offsets, sizes] :
+       llvm::zip_equal(newBbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+    newExtractOps.push_back(rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offsets, sizes, strides));
+  }
+}
+
+// If the outermost loop of the nested loop structure is scf.forall, the DPS
+// inits of the tiled consumer have to be redirected to slices of the loop's
+// block arguments.
+//
+// For each DPS init operand of `tiledConsumer` (which must implement
+// DestinationStyleOpInterface), a tensor.extract_slice of the matching entry of
+// `bbArgs` at `resultOffsets` / `resultSizes` (unit strides) is created right
+// before `tiledConsumer`, and the init operand is replaced with it. `bbArgs`,
+// `resultOffsets`, `resultSizes` and the DPS inits must have the same length.
+static void fixDpsInitsOfTiledConsumer(
+    RewriterBase &rewriter, Operation *tiledConsumer,
+    ArrayRef<BlockArgument> bbArgs,
+    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+    ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  auto dpsOp = cast<DestinationStyleOpInterface>(tiledConsumer);
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location loc = tiledConsumer->getLoc();
+
+  // Create all slices first so that the in-place modification below only
+  // swaps operands.
+  SmallVector<Value> newInits;
+  newInits.reserve(bbArgs.size());
+  for (auto [bbArg, offsets, sizes] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+    newInits.push_back(rewriter
+                           .create<tensor::ExtractSliceOp>(loc, bbArg, offsets,
+                                                           sizes, strides)
+                           .getResult());
+  }
+
+  // Mutate the operands through the rewriter so that its listeners are
+  // notified of the in-place modification of `tiledConsumer`.
+  MutableOperandRange dpsInits = dpsOp.getDpsInitsMutable();
+  rewriter.modifyOpInPlace(tiledConsumer, [&]() {
+    for (auto [dpsInit, newInit] : llvm::zip_equal(dpsInits, newInits))
+      dpsInit.set(newInit);
+  });
+}
+
+// Compute the tiles of all results of `tilableOp` given the slice `ossSliceOp`
+// of its operand `operandNumber`.
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+ RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+ OffsetSizeAndStrideOpInterface ossSliceOp,
+ SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+ // 1. check all stride all 1
+ if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+ }
+ // 2. compute iteration domain Tile from input position
----------------
Abhishek-Varma wrote:
```suggestion
// 2. compute iteration domain tile from the input position.
```
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list