[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)
llvmlistbot at llvm.org
Tue Jul 23 18:48:56 PDT 2024
================
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
tileAndFuseResult->tiledOps};
}
+/// Get the real producer from candidate ExtractSliceOp
+///
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+/// %2 = extract %arg1
+/// %3 = scf.for(%arg2 = %2)
+/// %4 = extract %arg2
+/// ...
+/// ```
+///
+/// @param candidateSliceOp: %4 = extract %arg2
+/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
+/// @return OpResult Producer : %0 = producer
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+ Operation *candidateSliceOp,
+ SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+ int maxDepth = 5) {
+ if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
+ return failure();
+ // Bound the recursion depth to avoid stack overflow.
+ if (curDepth > maxDepth)
+ return failure();
+
+ auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
+ backwardSlice.push_back(extractOp);
+ Value rootSource = extractOp.getSourceMutable().get();
+
+ while (true) {
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+ iterArg.getOwner()->getParentOp())) {
+ rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+ continue;
+ }
+ return failure();
+ } else if (auto sliceOp =
+ rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+ // walk up loop to find larger candidate extractSliceOp
+ return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+ curDepth + 1);
+ }
+ break;
+ }
+ return dyn_cast<OpResult>(rootSource);
+}
+
+/// Recursively find the outer nested loops of the given loop (inclusive) while
+/// the predicate succeeds, sorted from outer to inner.
+///
+/// @param loop: target loop. Note that this loop is always included, i.e.
+/// if no other nested loops are found, only the loop itself is returned.
+/// @param pred: predicate function; the recursion terminates when it
+/// fails.
+/// @return Outer Nest Loops: nested loops enclosing the given target loop (inclusive).
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+/// %1 = scf.for()
+/// %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without a specific predicate, this
+/// function will return three nested loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+ LoopLikeOpInterface loop,
+ const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
----------------
Yun-Fly wrote:
Like `ArrayRef::take_while(PredicateT Pred)`, which returns the first N elements of the array that satisfy the given predicate, this callback selects which kinds of outer loops we care about.
Besides, it is a shared utility that is also used in another producer-fusion patch ([here](https://github.com/llvm/llvm-project/pull/94190/files#diff-088c31bdb684f78d48f7607d0007daa4e1fba53b0361def9841f79631d8a3d5dR1278)), so it may be better to keep this callback for reuse.
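
To make the `take_while` analogy concrete, here is a minimal standalone sketch of the idea. It is not the code in this patch: `Loop`, its `parent` and `isParallel` fields, and `outerLoopsWhile` are hypothetical stand-ins for `LoopLikeOpInterface` and `getOuterNestLoopsWhile`, and it assumes (as the doc comment says) that the starting loop is always included.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for a loop op; this is NOT an MLIR type.
struct Loop {
  std::string name;
  const Loop *parent; // enclosing loop, or nullptr for the outermost loop
  bool isParallel;    // illustrative property a caller might filter on
};

// Collect `loop` plus its enclosing loops for as long as `pred` accepts them,
// then return the nest ordered from outer to inner. The walk stops at the
// first enclosing loop that `pred` rejects, similar in spirit to take_while.
std::vector<const Loop *>
outerLoopsWhile(const Loop *loop,
                const std::function<bool(const Loop *)> &pred) {
  assert(loop && "expected a valid starting loop");
  std::vector<const Loop *> nest{loop}; // the starting loop is always kept
  for (const Loop *p = loop->parent; p && pred(p); p = p->parent)
    nest.push_back(p);
  std::reverse(nest.begin(), nest.end()); // inner-to-outer -> outer-to-inner
  return nest;
}
```

With a nest `l0 > l1 > l2`, an always-true predicate yields all three loops, while a predicate such as "not parallel" stops the walk at the first parallel ancestor. Keeping the predicate as a parameter lets each caller decide where the walk ends without duplicating the traversal.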
https://github.com/llvm/llvm-project/pull/97803