[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 23 18:48:56 PDT 2024


================
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
                                            tileAndFuseResult->tiledOps};
 }
 
+/// Get the real producer from candidate ExtractSliceOp
+///
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %args2
+///      ...
+/// ```
+///
+/// @param candidateSliceOp: %4 = extract %args2
+/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
+/// @return OpResult Producer : %0 = producer
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
+    return failure();
+  // control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+
+  auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
+  backwardSlice.push_back(extractOp);
+  Value rootSource = extractOp.getSourceMutable().get();
+
+  while (true) {
+    if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+      if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+              iterArg.getOwner()->getParentOp())) {
+        rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+        continue;
+      }
+      return failure();
+    } else if (auto sliceOp =
+                   rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+      // walk up loop to find larger candidate extractSliceOp
+      return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+                                               curDepth + 1);
+    }
+    break;
+  }
+  return dyn_cast<OpResult>(rootSource);
+}
+
+/// Recursively find the outer nest loops of given loop(included) while the
+/// predict function succeed, sorted from outer to inner.
+///
+/// @param loop: target loop, note that this loop will be also included. I.e.
+///              if no other nest loops were found, just return itself.
+/// @param pred: predict function, the termination condition of recursive
+/// process.
+/// @return Outer Nest Loops: nest loops outside given target loop(included).
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without specific prediction function, this
+/// function will return three nest loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
----------------
Yun-Fly wrote:

This works like `ArrayRef::take_while(PredicateT Pred)`, which returns the first N elements of the array that satisfy the given predicate: the callback selects which kinds of outer loops we want to collect.

Besides, it is a shared utility that is also used in another producer-fusion patch ([here](https://github.com/llvm/llvm-project/pull/94190/files#diff-088c31bdb684f78d48f7607d0007daa4e1fba53b0361def9841f79631d8a3d5dR1278)), so it may be better to keep this callback for reuse.
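
To illustrate the `take_while`-style behaviour, here is a minimal standalone sketch. It is not the code in this patch: `Loop` is a hypothetical stand-in for `LoopLikeOpInterface`, the predicate returns `bool` rather than `LogicalResult`, and I am assuming the target loop is always included regardless of the predicate, as the doc comment suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for a loop op; this is not an MLIR type.
struct Loop {
  std::string name;
  bool hasIterArgs;   // illustrative property a predicate might test
  const Loop *parent; // enclosing loop, or nullptr at the top level
};

// Collects `loop` and then its enclosing loops for as long as `pred` holds,
// stopping at the first enclosing loop that fails it. The result is ordered
// from outermost to innermost. Assumption: `loop` itself is always included.
std::vector<const Loop *> getOuterNestLoopsWhile(
    const Loop *loop, const std::function<bool(const Loop *)> &pred) {
  assert(loop && "expected a non-null loop");
  std::vector<const Loop *> nest = {loop};
  for (const Loop *outer = loop->parent; outer && pred(outer);
       outer = outer->parent)
    nest.push_back(outer);
  // Loops were gathered inner-to-outer, so flip them.
  std::reverse(nest.begin(), nest.end());
  return nest;
}
```

With an always-true predicate and the three-level nest from the doc comment, calling this on `%2` collects `%0`, `%1`, `%2`. A predicate such as `[](const Loop *l) { return l->hasIterArgs; }` stops the walk at the first enclosing loop that has no iter args, which is the kind of filtering each caller can choose for itself.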

https://github.com/llvm/llvm-project/pull/97803

