[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 18:51:46 PDT 2024


================
@@ -1478,7 +1504,54 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
     return failure();
   if (containingOpBlock != consumerOp->getBlock())
     return failure();
-  return &operand;
+  return operand;
+}
+
+/// Collect the loops enclosing the given loop (inclusive) while the predicate
+/// function succeeds, sorted from outer to inner.
+///
+/// @param loop: target loop; it is always included, i.e. if no other nest
+///              loops are found, only the loop itself is returned.
+/// @param pred: predicate function; the walk stops at the first outer loop
+/// for which it fails.
+/// @return the nest loops enclosing the target loop (inclusive).
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given with a predicate that always succeeds, this
+/// function returns all three nest loops: %0, %1 and %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // sorted from outer to inner
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Check if the loop is a ForOp that yields the results of its inner loop.
+static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
+  if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+    Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations();
+    for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) {
+      // If the inner loop is the second-to-last op (right before the ForOp's
+      // yieldOp), the given loop must yield the result of the inner loop.
+      if (isa<LoopLikeOpInterface>(op)) {
----------------
MaheshRavishankar wrote:

Yeah, let's start simpler. How about something like this:

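```
auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
if (!forOp)
  return failure();
Block *body = forOp.getBody();
// The body must hold exactly one op besides the terminator.
if (!llvm::hasSingleElement(body->without_terminator()))
  return failure();
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
// That single op must itself be an scf.for.
auto innerForOp = dyn_cast<scf::ForOp>(&body->front());
if (!innerForOp)
  return failure();
if (innerForOp->getNumResults() != yieldOp->getNumOperands())
  return failure();
// Every result of the inner loop must be yielded by the outer loop.
if (llvm::any_of(innerForOp->getResults(), [&](Value result) {
      return !llvm::is_contained(yieldOp->getOperands(), result);
    }))
  return failure();
return success();
```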

Also, if you really want to use a "function" for the predicate check, can you move it into `getOuterNestLoopsWhile` as a lambda (and rename that to `getOuterNestLoops`)? This split seems artificial to me.
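To make the suggested shape concrete, here is a minimal sketch of the loop-nest walk with the predicate folded in as a lambda. It is a toy model so it compiles without MLIR: `ToyOp`, its fields, and the SSA ids are hypothetical stand-ins for `scf::ForOp`/`scf::YieldOp` and their values, and `getOuterNestLoops` is an illustrative name following the proposed rename, not code from the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical toy IR node standing in for an MLIR op; not the MLIR API.
struct ToyOp {
  bool isForLoop = false;     // models scf.for
  bool isYield = false;       // models scf.yield (the loop terminator)
  ToyOp *parent = nullptr;    // enclosing op, or nullptr at the top level
  std::vector<ToyOp *> body;  // ops in the loop body, terminator last
  std::vector<int> results;   // SSA ids of the values this op defines
  std::vector<int> operands;  // SSA ids of the values this op uses
};

// Collect `loop` and its enclosing loops, outermost first. The walk stops at
// the first parent that is not a for loop whose body holds a single inner for
// loop plus a yield forwarding every result of that inner loop.
std::vector<ToyOp *> getOuterNestLoops(ToyOp *loop) {
  assert(loop && loop->isForLoop && "expected a for loop");
  // The predicate lives here as a lambda rather than as a separate helper.
  auto yieldsInnerLoopResults = [](ToyOp *outer) -> bool {
    // Body must be exactly one inner op followed by the terminator.
    if (!outer->isForLoop || outer->body.size() != 2)
      return false;
    ToyOp *inner = outer->body.front();
    ToyOp *yield = outer->body.back();
    if (!inner->isForLoop || !yield->isYield)
      return false;
    if (inner->results.size() != yield->operands.size())
      return false;
    // Every inner-loop result must appear among the yielded values.
    return std::all_of(inner->results.begin(), inner->results.end(),
                       [&](int r) {
                         return std::find(yield->operands.begin(),
                                          yield->operands.end(),
                                          r) != yield->operands.end();
                       });
  };

  std::vector<ToyOp *> nest = {loop};
  for (ToyOp *outer = loop->parent; outer && yieldsInnerLoopResults(outer);
       outer = outer->parent)
    nest.push_back(outer);
  std::reverse(nest.begin(), nest.end());  // sorted from outer to inner
  return nest;
}
```

For a three-deep nest `%0 > %1 > %2` in which each outer loop yields its inner loop's result, calling the sketch on the innermost loop returns `{%0, %1, %2}`; if `%0` stopped forwarding `%1`'s result, the walk would stop at `%1` and return `{%1, %2}`.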


https://github.com/llvm/llvm-project/pull/94190

