[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 1 11:04:13 PDT 2024


================
@@ -949,6 +949,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
                                            tileAndFuseResult->tiledOps};
 }
 
+/// Trace a candidate tensor::ExtractSliceOp back to the OpResult it slices,
+/// stepping through loop iter_args and outer extract_slice ops. Example:
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %arg2
+///      ...
+/// ```
+/// NOTE(review): the tied loop init is dereferenced unchecked; confirm non-null.
+/// @param candidateSliceOp: innermost slice to start from (%4 = extract %arg2)
+/// @param backwardSlice: in-out; visited extract_slice ops, appended inner to outer
+/// @return the producer OpResult (%0 = producer), or failure if tracing fails
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
+    return failure();
+  // Cap the recursion depth to avoid a stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
+  backwardSlice.push_back(extractOp);
+  Value rootSource = extractOp.getSourceMutable().get();
+
+  while (true) {
+    if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+      if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+              iterArg.getOwner()->getParentOp())) {
+        rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+        continue;
+      }
+      return failure();
+    } else if (auto sliceOp =
+                   rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+      // Recurse on the outer slice; NOTE(review): maxDepth is not forwarded.
+      return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+                                               curDepth + 1);
+    }
+    break;
+  }
+  return dyn_cast<OpResult>(rootSource);
+}
+
+/// Collect the given loop plus its directly enclosing loops for as long as the
+/// predicate succeeds; the result is ordered from outermost to innermost.
+///
+/// @param loop: starting (innermost) loop; always included in the result, so
+///              it is returned alone when no enclosing loop qualifies.
+/// @param pred: predicate checked on each enclosing loop; the walk stops at the
+///              first parent op that is not a loop or for which it fails.
+/// @return the loop nest ending at `loop`, sorted from outer to inner.
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given with a predicate that always succeeds, this
+/// function returns the three nested loops: %0, %1, %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // Loops were collected inner to outer; return them outer to inner.
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Enhanced version for basic implementation of fusing producer, which can deal
+/// with multi-level candidates. E.g.
+///
+/// ```
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = tensor.extract_slice %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = tensor.extract_slice %arg2
+///      %5 = tiled_consumer ins(%4)
+/// ```
+///
+/// This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
+/// inner loop `%3 = scf.for`.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  SmallVector<tensor::ExtractSliceOp> backwardSlice;
+  if (failed(
+          getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
+    return std::nullopt;
+  }
+
+  std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+  // reverse from outer to inner
+  std::reverse(backwardSlice.begin(), backwardSlice.end());
+  // multiple application of `tileAndFuseProducerOfSliceImpl`
+  for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
+    // get nest loops between next candidate sliceOp and tiled producer.
+    auto whileProducerOutOfLoopBlock =
----------------
MaheshRavishankar wrote:

Let's say you are starting with this:
```
%0 = scf.for ... (%arg0 = %1) {
    %2 = scf.for ... (%arg1 = %arg0) {
        %3 ... scf.for ... (%arg2 = %arg1 ) {
          %4 ....
          %5...
           scf.yield %4
        }
        scf.yield %3
    }
    scf.yield %2
}    
```

Now, if you want to yield `%5` on top of that, the logic only works if `%5` is in the innermost loop, the loop defining `%2` yields the result of the loop defining `%3`, and the loop defining `%0` yields the result of the loop defining `%2` (and so on...).



https://github.com/llvm/llvm-project/pull/97803


More information about the Mlir-commits mailing list