[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 1 11:04:13 PDT 2024
================
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
tileAndFuseResult->tiledOps};
}
+/// Get the real producer from candidate ExtractSliceOp
+///
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///     %4 = extract %arg2
+///     ...
+/// ```
+///
+/// Starting from `candidateSliceOp`, walk backwards through the source operand
+/// of each `tensor.extract_slice`. A source that is a loop region iter_arg is
+/// replaced by the init value tied to it, and a source produced by another
+/// `tensor.extract_slice` continues the walk from that (larger) slice.
+///
+/// @param candidateSliceOp: %4 = extract %arg2
+/// @param backwardSlice: in-out parameter; every visited extractSliceOp is
+/// appended, starting with `candidateSliceOp` (i.e. inner to outer order).
+/// @param curDepth: number of extractSliceOps already visited by the walk.
+/// @param maxDepth: the walk fails once `curDepth` exceeds this value, which
+/// bounds the recursion depth.
+/// @return OpResult Producer : %0 = producer. Fails if `candidateSliceOp` is
+/// not a `tensor.extract_slice`, if the walk reaches a block argument that is
+/// not an iter_arg of an enclosing loop, or if `maxDepth` is exceeded.
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  auto extractOp = dyn_cast<tensor::ExtractSliceOp>(candidateSliceOp);
+  if (!extractOp)
+    return failure();
+  // Bound the recursion depth to avoid a stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  backwardSlice.push_back(extractOp);
+  Value rootSource = extractOp.getSource();
+
+  while (true) {
+    if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+      // The owning block may be detached, so its parent op can be null.
+      auto outerLoop = dyn_cast_if_present<LoopLikeOpInterface>(
+          iterArg.getOwner()->getParentOp());
+      if (!outerLoop)
+        return failure();
+      // Not every loop block argument is a region iter_arg; for the others
+      // `getTiedLoopInit` returns null and there is no init value to follow.
+      OpOperand *tiedInit = outerLoop.getTiedLoopInit(iterArg);
+      if (!tiedInit)
+        return failure();
+      rootSource = tiedInit->get();
+      continue;
+    }
+    if (auto sliceOp = rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+      // Walk up to the larger candidate extractSliceOp, forwarding the
+      // caller's depth limit instead of falling back to the default.
+      return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+                                               curDepth + 1, maxDepth);
+    }
+    break;
+  }
+  // A Value is either a BlockArgument or an OpResult, and block arguments are
+  // fully handled inside the loop above.
+  return cast<OpResult>(rootSource);
+}
+
+/// Collect the loops enclosing `loop` (with `loop` itself included) by walking
+/// outwards through directly nested parent loops while `pred` succeeds. The
+/// result is sorted from outer to inner.
+///
+/// @param loop: target loop. It is always included, i.e. if no enclosing loop
+/// is collected, the result is just `{loop}`.
+/// @param pred: predicate evaluated on each candidate parent loop; the walk
+/// stops at the first parent op that is not a loop or for which `pred` fails.
+/// @return Outer Nest Loops: nest loops around the given target loop
+/// (included), outermost first.
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+///   %1 = scf.for()
+///     %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given with a predicate that always succeeds, this
+/// function returns three nest loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  // `getParentOp()` is null for a top-level op; `dyn_cast` would assert on it.
+  auto outerLoop =
+      dyn_cast_if_present<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop =
+        dyn_cast_if_present<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // Loops were collected inner to outer; reverse so the outermost comes first.
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Extension of the basic producer-fusion implementation that can also handle
+/// multi-level candidates. E.g.
+///
+/// ```
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+/// %2 = tensor.extract_slice %arg1
+/// %3 = scf.for(%arg2 = %2)
+/// %4 = tensor.extract_slice %arg2
+/// %5 = tiled_consumer ins(%4)
+/// ```
+///
+/// This utility can fuse the untiled producer at `%4 = tensor.extract_slice`
+/// within the inner loop `%3 = scf.for`.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ SmallVector<tensor::ExtractSliceOp> backwardSlice;
+ if (failed(
+ getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
+ return std::nullopt;
+ }
+
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+ // reverse from outer to inner
+ std::reverse(backwardSlice.begin(), backwardSlice.end());
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
+ for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
+ // get nest loops between next candidate sliceOp and tiled producer.
+ auto whileProducerOutOfLoopBlock =
----------------
MaheshRavishankar wrote:
Let's say you are starting with this:
```
%0 = scf.for ... (%arg0 = %1) {
%2 = scf.for ... (%arg1 = %arg0) {
%3 ... scf.for ... (%arg2 = %arg1 ) {
%4 ....
%5...
scf.yield %4
}
scf.yield %3
}
scf.yield %2
}
```
Now, if you want to yield `%5` on top of that, the logic only works if `%5` is in the innermost loop, the loop defining `%2` is yielding the result of the loop defining `%3`, and the loop defining `%0` is yielding the result of the loop defining `%2` (and so on...).
https://github.com/llvm/llvm-project/pull/97803
More information about the Mlir-commits
mailing list