[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
llvmlistbot at llvm.org
Mon Apr 15 14:40:06 PDT 2024
================
@@ -820,6 +874,429 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] =
+ getUntiledConsumerFromSliceDestSCFForall(
+ &candidateSliceOp.getDestMutable(), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
----------------
MaheshRavishankar wrote:
Don't use `llvm::outs()`. Use
```
consumerOp.value()->emitOpError("...");
```
or
```
rewriter.notifyMatchFailure(consumerOp, "...");
```
here and everywhere else.
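Below is a minimal, self-contained sketch of why this matters. It uses hypothetical stand-in types (`Op`, `MockRewriter`, `fuseConsumer`), not the real MLIR classes. In MLIR, `RewriterBase::notifyMatchFailure` forwards the reason to the rewriter's listener (for example, the pattern driver's debug logging) and returns a failed `LogicalResult`, which converts to the `FailureOr` this function returns. The sketch models that with `std::optional`: the early exit becomes `return rewriter.notifyMatchFailure(...)` and nothing is written to stdout.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins, NOT the real MLIR API. They only model the idea
// behind RewriterBase::notifyMatchFailure: the reason goes to a listener
// owned by whoever drives the rewrite, and the caller gets a failure value.
struct Op {
  std::string name;
  bool implementsTilingInterface;
};

struct MockRewriter {
  // Installed by the driver, which decides whether and where to log.
  std::function<void(const Op &, const std::string &)> listener;

  // Reports the reason (if anyone is listening) and returns a value that
  // converts to an empty optional, loosely like LogicalResult -> FailureOr.
  std::nullopt_t notifyMatchFailure(const Op &op, const std::string &reason) {
    if (listener)
      listener(op, reason);
    return std::nullopt;
  }
};

// Hypothetical helper mirroring the early-exit check in the diff.
std::optional<std::string> fuseConsumer(MockRewriter &rewriter,
                                        const Op &consumer) {
  if (!consumer.implementsTilingInterface)
    return rewriter.notifyMatchFailure(consumer,
                                       "consumer is not a TilingInterface");
  return "fused " + consumer.name;
}
```

With no listener installed, the failure is silent and the caller simply sees an empty result. `emitOpError` is the better fit when the condition should surface to the user as a real error, since it is routed through the context's diagnostic engine instead.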
https://github.com/llvm/llvm-project/pull/88712