[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 15 14:40:04 PDT 2024
================
@@ -820,6 +874,429 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] =
+ getUntiledConsumerFromSliceDestSCFForall(
+ &candidateSliceOp.getDestMutable(), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+ << "\n";
+ return failure();
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ llvm::outs() << "containing op is not a scf.forall: " << containingOp
+ << "\n";
+ return failure();
+ }
+
+ // Check consumer don't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
----------------
MaheshRavishankar wrote:
To start with, I'd use a simpler check: that the containing op's result has only a single use.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list