[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
Abhishek Varma
llvmlistbot at llvm.org
Mon Jun 3 01:57:00 PDT 2024
================
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
- } else {
+// Traverse and collect all outer loops of given sliceOp, sorted by
+// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+ OffsetSizeAndStrideOpInterface sliceOp,
+ std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+ SmallVector<LoopLikeOpInterface> outerLoops;
+ auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ outerLoops.push_back(forOp);
+ if (untilLoop.has_value() && *untilLoop == forOp)
+ break;
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
+// ```
+// %1 = scf.for
+// %2 = scf.for
+// %3 = scf.for
+// ...
+// %4 = insert
+// yield %4
+// %5 = insert %3
+// yield %5
+// yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+// Collected insertSliceOp List during walk including targetSliceOp:
+// %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+ OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+ int maxDepth = 5) {
+ // control recursive time in avoid of stack overflow
+ if (curDepth > maxDepth)
+ return failure();
+
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+ candidateSliceOpList.push_back(targetSliceOp);
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue);
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
return failure();
+
+ while (true) {
+ bool walkThroughOuterLoop = false;
+ for (auto &useOperand : resultOfLoop.getUses()) {
+ if (auto sliceOp =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+ auto resultAndSliceOpsPair =
+ getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
+ if (failed(resultAndSliceOpsPair))
+ return failure();
+ candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+ (*resultAndSliceOpsPair).second.end());
+ return std::make_pair((*resultAndSliceOpsPair).first,
+ candidateSliceOpList);
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ // walk through outer loop
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ walkThroughOuterLoop = true;
+ break;
+ }
+ }
+ if (!walkThroughOuterLoop)
+ break;
}
+ return std::make_pair(resultOfLoop, candidateSliceOpList);
}
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
----------------
Abhishek-Varma wrote:
I guess @ftynse can comment on this better, since I'm still learning too, but I believe we are supposed to use `ArrayRef` here, since we aren't mutating `resultOffsets` and `resultSizes` inside the function.
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list