[Mlir-commits] [mlir] [mlir][SCF] Add `scf::tileAndFuseConsumer` that tiles a consumer into a given tiled loop nest. (PR #167634)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 20 11:52:11 PST 2025


================
@@ -2428,11 +2391,173 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // If `loops` is empty, return an error for now. The caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidate slices need to be all `tensor.insert_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
+  }
+
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1) {
+    return std::nullopt;
+  }
+  return combiningOps[0];
+}
+
+/// For a given result of the loop nest that is a tiled loop nest, return the
+/// insert slice-like op that is used for consumer fusion
+std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
----------------
MaheshRavishankar wrote:

Actually, I know why I needed this. `loops.front()` returns a `const` reference, and `getOperation` isn't marked `const`, so I get the following error:
```
error: 'this' argument to member function 'getOperation' has type 'const mlir::LoopLikeOpInterface', but function is not marked const
```
So I needed to create a temporary copy without the `const`.

This choice by MLIR (https://mlir.llvm.org/docs/Rationale/UsageOfConst/) seems to keep running into issues like this. I don't disagree with the choice itself, but it has some rough edges like this.
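
For readers who have not hit this before, here is a minimal standalone sketch of the situation. It does not use MLIR: `Operation`, `LoopHandle`, `ConstView`, and `getOutermostOperation` are hypothetical stand-ins I made up for `mlir::Operation`, `LoopLikeOpInterface`, `ArrayRef`, and the helper in the diff. The only assumptions carried over are that the handle is a cheap pointer wrapper whose accessor is not `const`-qualified, and that the view's `front()` returns a `const` reference.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mlir::Operation.
struct Operation {
  int id;
};

// Hypothetical stand-in for an op/interface handle: a value type that wraps
// a pointer. Its accessor is deliberately NOT const-qualified.
class LoopHandle {
public:
  explicit LoopHandle(Operation *op) : op(op) {}
  Operation *getOperation() { return op; } // note: no `const` here
private:
  Operation *op;
};

// Simplified stand-in for a read-only array view: front() hands out a
// const reference to the first element.
template <typename T>
class ConstView {
public:
  ConstView(const std::vector<T> &v) : data(v.data()), count(v.size()) {}
  bool empty() const { return count == 0; }
  const T &front() const { return data[0]; }
private:
  const T *data;
  std::size_t count;
};

Operation *getOutermostOperation(ConstView<LoopHandle> loops) {
  assert(!loops.empty() && "expected loops to be non-empty");
  // Calling the accessor directly on the const reference does not compile:
  //   loops.front().getOperation();
  // because `this` would be `const LoopHandle` and the member is not const.
  // Copying into a local drops the const. The copy is cheap because the
  // handle only holds a pointer.
  LoopHandle outermostLoop = loops.front();
  return outermostLoop.getOperation();
}
```

The local copy mirrors the `LoopLikeOpInterface outermostLoop = loops.front();` line in the diff. Whether a copy or some other workaround is preferable in the real code is a judgment call for the PR; this sketch only illustrates why the direct call is rejected.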

https://github.com/llvm/llvm-project/pull/167634

