[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 3 00:43:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

<details>
<summary>Changes</summary>

Hi, based on the earlier discussion in [this thread](https://github.com/llvm/llvm-project/pull/88712#discussion_r1590568717), this patch extends the new consumer-fusion feature to more complex nested loop structures. E.g.

```
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @<!-- -->fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %5 : tensor<256x256xf32>
  }
}
```

#### What's New in this PR:
1. Support nested loop structures, including both `scf.for` and `scf.forall`.
2. Support multi-level `insert_slice` or `parallel_insert_slice`.


NOTE: this PR DOES NOT deal with the refactoring of `getTiledImplementation` that we talked about before; it focuses only on the functionality enhancement. BTW, in the example above you can also see a similar issue: the semantics of the tiled operand do not match the assumption of the current `getTiledImplementation`, even for `dpsInits`. To unblock this necessary patch, I temporarily follow the method that @<!-- -->MaheshRavishankar suggested, using a dummy `insert_slice` to bridge that gap.

The resulting IR will finally look like this:
```
#map = affine_map<(d0) -> (d0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0 + d1 * 128)>
module {
  module {
    func.func @<!-- -->fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
      %c0 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c128 = arith.constant 128 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = tensor.empty() : tensor<256x256xf32>
      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
      %2:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1, %arg6 = %0) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
        %3 = affine.apply #map(%arg3)
        %4 = affine.apply #map(%arg4)
        %extracted_slice = tensor.extract_slice %arg5[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %extracted_slice_0 = tensor.extract_slice %arg0[%3, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
        %extracted_slice_1 = tensor.extract_slice %arg1[0, %4] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
        %extracted_slice_2 = tensor.extract_slice %arg6[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %5:2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice, %arg9 = %extracted_slice_2) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
          %6:2 = scf.for %arg10 = %c0 to %c128 step %c64 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
            %extracted_slice_3 = tensor.extract_slice %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %extracted_slice_4 = tensor.extract_slice %extracted_slice_0[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
            %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[0, %arg10] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
            %7 = linalg.matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
            %8 = affine.apply #map1(%arg7, %arg3)
            %9 = affine.apply #map1(%arg10, %arg4)
            %extracted_slice_6 = tensor.extract_slice %arg2[%8, %9] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
            %extracted_slice_7 = tensor.extract_slice %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %10 = linalg.add ins(%7, %extracted_slice_6 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%extracted_slice_7 : tensor<64x64xf32>) -> tensor<64x64xf32>
            %inserted_slice = tensor.insert_slice %7 into %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            %inserted_slice_8 = tensor.insert_slice %10 into %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            scf.yield %inserted_slice, %inserted_slice_8 : tensor<128x128xf32>, tensor<128x128xf32>
          }
          scf.yield %6#<!-- -->0, %6#<!-- -->1 : tensor<128x128xf32>, tensor<128x128xf32>
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %5#<!-- -->1 into %arg6[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
          tensor.parallel_insert_slice %5#<!-- -->0 into %arg5[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
        }
      }
      return %2#<!-- -->1 : tensor<256x256xf32>
    }
  }
}
```

Looking forward to your suggestions and review, thanks.

---

Patch is 46.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94190.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+511-225) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+96) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..9dd730e64a030 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+// Traverse and collect all outer loops of given sliceOp, sorted by
+// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    OffsetSizeAndStrideOpInterface sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
+// ```
+// %1 = scf.for
+//  %2 = scf.for
+//   %3 = scf.for
+//      ...
+//      %4 = insert
+//      yield %4
+//   %5 = insert %3
+//   yield %5
+//  yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+//         Collected insertSliceOp List during walk including targetSliceOp:
+//                %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+    OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+    int maxDepth = 5) {
+  // control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(targetSliceOp);
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+          targetSliceOp.getOperation())) {
+    Value destValue = sliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(destValue);
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+                 targetSliceOp.getOperation())) {
+    Value resultValue = sliceOp.getResult();
+    for (auto &useOperand : resultValue.getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (llvm::detail::isPresent(resultOfLoop))
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!llvm::detail::isPresent(resultOfLoop))
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (auto &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        auto resultAndSliceOpsPair =
+            getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
+        if (failed(resultAndSliceOpsPair))
+          return failure();
+        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+                                    (*resultAndSliceOpsPair).second.end());
+        return std::make_pair((*resultAndSliceOpsPair).first,
+                              candidateSliceOpList);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // walk through outer loop
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1247,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1267,18 +1260,17 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 
 /// After fusing consumer into scf.forall we want to yield each of the resulting
 /// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1278,180 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+// If the top level loop of nested loop structure is scf.forall, need to create
+// additional tensor.extract_slice for its new appended `shared_outs` in order
+// to pass correct local memory for inner loops. E.g.
+//
+// scf.forall shared_outs(%o1=..., %o2=...) {
+//     %local_o1 = extract_slice %o1
+//     // fix new appended `shared_out` %o2
+//     %local_o2 = extract_slice %o2
+//     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+//        ...
+//     }
+//     ...
+// }
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+                      LoopLikeOpInterface innerLoop,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      unsigned newInitSize,
+                      SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+  rewriter.setInsertionPoint(innerLoop);
+  Location Loc = outerLoop.getLoc();
+  MutableArrayRef<BlockArgument> bbArgs = outerLoop.getBody()->getArguments();
+
+  SmallVector<tensor::ExtractSliceOp> newOps;
+  newOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] : llvm::zip_equal(
+           bbArgs.take_back(newInitSize), resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    newOps.push_back(newExtractOp);
+  }
+  newExtractOps = newOps;
+}
+
+// If outerMost loop of nested loop structure is `scf.forall`, need to deal with
+// DpsInit of tiled consumer
+static void fixDpsInitsOfTiledConsumer(
+    RewriterBase &rewriter, Operation *tiledConsumer,
+    ArrayRef<BlockArgument> bbArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location Loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
+}
+
+// compute all results tile by given SliceOp along operand
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. check all stride all 1
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. compute iteration domain Tile from input position
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  // 3. compute result Tile by resultNumber
+  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = resultOffsets;
+  allResultSizes = resultSizes;
+  return success();
+}
+
+// Considering multi-level tensor.*SliceOp maybe based on different
+// coordination, this utility computes the real OFFSET coordinated on ROOT
+// SliceOp. E.g
+//             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+//         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+//
+// where the coordination can be illustrated as follow:
+//
+//  %3 ----------------------------------
+//  |         |         |
+//  | OFFSET2 | OFFSET1 |
+//  | ------ %0         |
+//  |                   |
+//  |                   |
+//  |------------------ %1 ------ |
+//  |                   |  SIZE1  |
+//  |                   |         |
+//  |                   |         |
+//  |                   | ------- |
+//  |
+//
+// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  }
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // real offsets equals to accumulative offsets of outer candidates
+  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+       iter++) {
+    // assert each outer candidate slice has no stride
+    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return failure();
+    }
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+      using AVE = affine::AffineValueExpr;
+      affine::AffineBuilder ab(rewriter, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(rewriter.getContext(), dim0, dim1);
+      bindSymbols(rewriter.getContext(), sym);
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+// Get the first tilable user of given Value and check its domination at the
+// same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  for (auto &useOfval : val.getUses()) {
+    Operation *consumerOp = useOfval.getOwner();
+    // 1. Check whether consumerOp is tilable
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      continue;
+    // 2. check stay in same block with loopOp
+    if (loopOp->getBlock() != consumerOp->getBlock())
+      continue;
+    // 3. check no other user before it
+    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+      continue;
+    }
+    return &useOfval;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1463,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool is...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list