[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
Abhishek Varma
llvmlistbot at llvm.org
Mon Jun 3 01:56:59 PDT 2024
================
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
- } else {
+// Collect all loops enclosing the given sliceOp, ordered from outermost to
+// innermost. If `untilLoop` is given, the walk stops once it has been reached.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+ OffsetSizeAndStrideOpInterface sliceOp,
+ std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+ SmallVector<LoopLikeOpInterface> outerLoops;
+ auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ outerLoops.push_back(forOp);
+ if (untilLoop.has_value() && *untilLoop == forOp)
+ break;
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+// Get the result of the top-level loop that yields the target InsertSliceOp. E.g.
+// ```
+// %1 = scf.for
+// %2 = scf.for
+// %3 = scf.for
+// ...
+// %4 = insert
+// yield %4
+// %5 = insert %3
+// yield %5
+// yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+// Collected insertSliceOp List during walk including targetSliceOp:
+// %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+ OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+ int maxDepth = 5) {
+ // Bound the recursion depth to avoid a stack overflow.
+ if (curDepth > maxDepth)
+ return failure();
+
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+ candidateSliceOpList.push_back(targetSliceOp);
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue);
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
return failure();
+
+ while (true) {
+ bool walkThroughOuterLoop = false;
+ for (auto &useOperand : resultOfLoop.getUses()) {
----------------
Abhishek-Varma wrote:
Here as well, `auto` should be expanded to the explicit type.
This improves the readability of the code.
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list