[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

Abhishek Varma llvmlistbot at llvm.org
Mon Jun 3 01:57:00 PDT 2024

@@ -1267,25 +1260,198 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 /// After fusing consumer into scf.forall we want to yield each of the resulting
 /// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   Location firstYieldOpLoc =
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
         firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
+// If the top-level loop of the nested loop structure is scf.forall, we need to
+// create an additional tensor.extract_slice for each of its newly appended
+// `shared_outs` block arguments, so that the inner loops receive the correct
+// slice rather than the whole shared output. E.g.
+// scf.forall shared_outs(%o1=..., %o2=...) {
+//     %local_o1 = extract_slice %o1
+//     // fix new appended `shared_out` %o2
+//     %local_o2 = extract_slice %o2
+//     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+//        ...
+//     }
+//     ...
+// }
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+                      LoopLikeOpInterface innerLoop,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      unsigned newInitSize,
+                      SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+  rewriter.setInsertionPoint(innerLoop);
+  Location Loc = outerLoop.getLoc();
+  MutableArrayRef<BlockArgument> bbArgs = outerLoop.getBody()->getArguments();
+  SmallVector<tensor::ExtractSliceOp> newOps;
+  newOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] : llvm::zip_equal(
+           bbArgs.take_back(newInitSize), resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    newOps.push_back(newExtractOp);
+  }
+  newExtractOps = newOps;
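+// If the outermost loop of the nested loop structure is `scf.forall`, we need to
+// fix up the DPS inits of the tiled consumer: each init is replaced with a
+// tensor.extract_slice of the corresponding block argument.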
+static void fixDpsInitsOfTiledConsumer(
+    RewriterBase &rewriter, Operation *tiledConsumer,
+    ArrayRef<BlockArgument> bbArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location Loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
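+// Compute the tile (offsets and sizes) of every result of `tilableOp`, given a
+// slice op on its operand `operandNumber`. Fails if the slice is not
+// unit-strided.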
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. check all stride all 1
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. compute iteration domain Tile from input position
Abhishek-Varma wrote:

// 2. compute iteration domain tile from the input position.


More information about the Mlir-commits mailing list