[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 23 12:33:10 PDT 2024


================
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
                                            tileAndFuseResult->tiledOps};
 }
 
+/// Get the real producer from candidate ExtractSliceOp
+///
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %args2
+///      ...
+/// ```
+///
+/// @param candidateSliceOp: %4 = extract %args2
+/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
+/// @return OpResult Producer : %0 = producer
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
+    return failure();
+  // control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+
+  auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
+  backwardSlice.push_back(extractOp);
+  Value rootSource = extractOp.getSourceMutable().get();
+
+  while (true) {
+    if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+      if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+              iterArg.getOwner()->getParentOp())) {
+        rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+        continue;
+      }
+      return failure();
+    } else if (auto sliceOp =
+                   rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+      // walk up loop to find larger candidate extractSliceOp
+      return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+                                               curDepth + 1);
+    }
+    break;
+  }
+  return dyn_cast<OpResult>(rootSource);
+}
+
+/// Recursively find the outer nest loops of given loop(included) while the
+/// predict function succeed, sorted from outer to inner.
+///
+/// @param loop: target loop, note that this loop will be also included. I.e.
+///              if no other nest loops were found, just return itself.
+/// @param pred: predict function, the termination condition of recursive
+/// process.
+/// @return Outer Nest Loops: nest loops outside given target loop(included).
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without specific prediction function, this
+/// function will return three nest loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // sorted from outer to inner
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Enhanced version for basic implementation of fusing producer, which can deal
+/// with multi-level candidates. E.g.
+///
+/// ```
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = tensor.extract_slice %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = tensor.extract_slice %args2
+///      %5 = tiled_consumer ins(%4)
+/// ```
+///
+/// This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
+/// inner loop `%3 = scf.for`.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  SmallVector<tensor::ExtractSliceOp> backwardSlice;
+  if (failed(
+          getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
+    return std::nullopt;
+  }
+
+  std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+  // reverse from outer to inner
+  std::reverse(backwardSlice.begin(), backwardSlice.end());
+  // multiple application of `tileAndFuseProducerOfSliceImpl`
+  for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
+    // get nest loops between next candidate sliceOp and tiled producer.
+    auto whileProducerOutOfLoopBlock =
----------------
MaheshRavishankar wrote:

This is not enough. You need to check that the `extract_slice` is in the innermost loop of the loop nest generated by tiling. The reason is that when you are yielding the value of a fused producer, that yielding only works (AFAIK) if
1) The innermost loop returns the value
2) All surrounding outer loops just yield the result of the inner loops (and the loop nest just passes the init values along from the outer loop to the inner loop).

So if we want that to still hold, then this loop nest needs to be checked that way...
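
To make the two conditions concrete, here is a minimal sketch of the check on a toy model. It deliberately does not use the MLIR API: `ToyLoop`, `isPureForwardingNest`, and `canFuseAtSlice` are hypothetical names invented for illustration, and each loop is reduced to two index vectors describing where the nested loop's inits come from and what the loop yields. A real check would presumably inspect the actual loop ops and their yield operands instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model of one loop level; this is NOT an MLIR type.
struct ToyLoop {
  // innerInits[i]: index of this loop's iter_arg that is passed as init i of
  // the directly nested loop, or -1 if that init comes from anywhere else.
  std::vector<int> innerInits;
  // yields[i]: index of the nested loop's result that this loop yields as its
  // own result i, or -1 if the yielded value is computed some other way.
  std::vector<int> yields;
};

// `nest` is ordered from outermost to innermost. Returns true if every loop
// except the innermost only forwards values: iter_arg i goes in as init i of
// the nested loop, and result i of the nested loop is yielded as result i.
// The innermost loop's fields are ignored, since it is the one that computes
// and returns the value.
bool isPureForwardingNest(const std::vector<ToyLoop> &nest) {
  if (nest.empty())
    return false;
  for (std::size_t level = 0; level + 1 < nest.size(); ++level) {
    const ToyLoop &loop = nest[level];
    if (loop.innerInits.size() != loop.yields.size())
      return false;
    for (std::size_t i = 0; i < loop.yields.size(); ++i) {
      int expected = static_cast<int>(i);
      if (loop.innerInits[i] != expected || loop.yields[i] != expected)
        return false;
    }
  }
  return true;
}

// `sliceLevel` is the index in `nest` of the loop whose body holds the
// candidate slice. Fusion is accepted only if the slice sits in the innermost
// loop and all outer loops are pure forwarders.
bool canFuseAtSlice(const std::vector<ToyLoop> &nest, std::size_t sliceLevel) {
  assert(sliceLevel < nest.size() && "slice must belong to a loop in the nest");
  return sliceLevel + 1 == nest.size() && isPureForwardingNest(nest);
}
```

In this model, a two-level nest whose outer loop forwards its single iter_arg and yields the inner result passes the check only when the slice is at level 1 (the innermost loop). An outer loop that yields anything other than the inner loop's result is rejected.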

https://github.com/llvm/llvm-project/pull/97803

