[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 21:50:30 PDT 2024
================
@@ -1100,6 +1101,428 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// In the following function `source` is the source operand of a
+/// tensor.insert_slice op whose result is yielded from its containing loop
+/// (expected to be an scf.for). We follow the use-def chain from the
+/// insert_slice result, through the scf.yield in the same parent op, to the
+/// matching loop result, and return the first user of that result as the
+/// untiled consumer.
+///
+/// On success `operandNumber` is set to the lowest operand index at which the
+/// consumer reads the loop result. The returned consumer is nullptr (and
+/// `operandNumber` is left untouched) when the insert_slice result is not
+/// yielded by an scf.yield in the same parent op, when the yield operand index
+/// has no corresponding parent result, or when that result has no users.
+/// The second tuple element is currently never populated and is always
+/// std::nullopt.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  Operation *sliceOp = source.getOwner();
+  Operation *containingOp = sliceOp->getParentOp();
+
+  // Step 1. Find the scf.yield operand that carries the insert_slice result.
+  // Its operand index (not the operand index of `source`, which is always the
+  // slice's own operand slot) selects the loop result holding the tiled value.
+  // Only a yield terminating the same parent op is considered, so a yield of a
+  // nested region cannot be mistaken for the loop terminator.
+  OpOperand *yieldUse = nullptr;
+  for (OpOperand &use : sliceOp->getResult(0).getUses()) {
+    Operation *user = use.getOwner();
+    if (isa<scf::YieldOp>(user) && user->getParentOp() == containingOp) {
+      yieldUse = &use;
+      break;
+    }
+  }
+  if (!yieldUse)
+    return {nullptr, std::nullopt};
+
+  // Step 2. Map the yield operand to the corresponding loop result.
+  unsigned yieldOperandNumber = yieldUse->getOperandNumber();
+  if (yieldOperandNumber >= containingOp->getNumResults())
+    return {nullptr, std::nullopt};
+  Value resultingValue = containingOp->getResult(yieldOperandNumber);
+
+  // Step 3. Pick the first user of that loop result as the consumer.
+  // TODO(avarma): Address the case where the consumer op itself can return
+  // more than one result.
+  if (resultingValue.use_empty())
+    return {nullptr, std::nullopt};
+  Operation *untiledConsumer = *resultingValue.getUsers().begin();
+
+  // Report the lowest operand index at which the consumer reads the value.
+  for (OpOperand &operand : untiledConsumer->getOpOperands()) {
+    if (operand.get() == resultingValue) {
+      operandNumber = operand.getOperandNumber();
+      break;
+    }
+  }
+  return {untiledConsumer, std::nullopt};
+}
+
+/// Returns true if `op` meets the structural preconditions for consumer
+/// fusion: it has at least one result, its first result has exactly one use,
+/// and, when `op` is a tensor.insert_slice, that single user is an scf.yield.
+/// Emits a debug message describing the first violated condition otherwise.
+static bool checkAssumptionForFusingConsumer(Operation *op) {
+  // Guard the getResult(0) below; an op without results cannot be fused.
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Candidate slice op has no results\n");
+    return false;
+  }
+  Value result = op->getResult(0);
+  if (result.use_empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Candidate slice op has no uses\n");
+    return false;
+  }
+  // hasOneUse() counts uses, so an op reading the value twice is rejected,
+  // exactly as counting the elements of the user range would.
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return false;
+  }
+  Operation *soleUser = *result.getUsers().begin();
+  if (isa<tensor::InsertSliceOp>(op) && !isa<scf::YieldOp>(soleUser)) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ // ASSUMING THAT YIELD OP IS ONLY YIELDING JUST ONE VALUE.
+ if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "needs only scf.yield as its user");
+ }
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+ candidateSliceOp->getOpOperand(0), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer is not a TileableInterface");
+ }
+
+ // Check containing op is "scf::ForOp".
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp) {
+ return rewriter.notifyMatchFailure(containingOp,
+ "containing op is not a scf.for");
+ }
+
+ // Check containingOp has exactly one use.
+ assert(forOp.getResults().size() == 1 &&
----------------
MaheshRavishankar wrote:
If you want to support `results.size() > 1` later, that's fine, but do not assert on that. Fold that into the check method below (though I think it is really easy to avoid this check right now).
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list