[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

Abhishek Varma llvmlistbot at llvm.org
Mon Jun 3 01:56:59 PDT 2024

@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+// Traverse and collect all outer loops of given sliceOp, sorted by
+// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    OffsetSizeAndStrideOpInterface sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
+// ```
+// %1 = scf.for
+//  %2 = scf.for
+//   %3 = scf.for
+//      ...
+//      %4 = insert
+//      yield %4
+//   %5 = insert %3
+//   yield %5
+//  yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+//         Collected insertSliceOp List during walk including targetSliceOp:
+//                %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+    OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+    int maxDepth = 5) {
+  // control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(targetSliceOp);
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+          targetSliceOp.getOperation())) {
Abhishek-Varma wrote:

Although you're passing in an `OffsetSizeAndStrideOpInterface`, I don't think the interface's methods are being used in any way.

You might as well pass in an `Operation *` and just add an assert that `targetSliceOp` is an `OffsetSizeAndStrideOpInterface` instead.


More information about the Mlir-commits mailing list