[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Thu Apr 18 04:58:20 PDT 2024
================
@@ -1100,6 +1101,451 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Follows the use-def chain from `candidateSliceOp` through its parent op
+/// (expected to be an scf.for) to the first consumer of the loop result that
+/// the slice feeds.
+///
+/// The slice result is expected to be used by an scf.yield. The position of
+/// the slice among the yielded values is written to `resultNumber` and selects
+/// the parent-op result to inspect. The first user of that result is returned
+/// as the consumer, and `operandNumber` receives the lowest operand index at
+/// which that consumer uses the result.
+///
+/// Both out-parameters are reset to 0 on entry. Returns nullptr (instead of
+/// an indeterminate pointer or an assertion) when the slice is not yielded,
+/// when the parent op has no matching result, or when that result is unused.
+static Operation *
+getUntiledConsumerFromSliceDestSCFFor(tensor::InsertSliceOp candidateSliceOp,
+                                      unsigned &operandNumber,
+                                      unsigned &resultNumber) {
+  operandNumber = 0;
+  resultNumber = 0;
+
+  // Step 1. Find which result of the parent op the slice is yielded as.
+  Value sliceResult = candidateSliceOp.getResult();
+  if (sliceResult.use_empty())
+    return nullptr;
+  auto yieldOp = dyn_cast<scf::YieldOp>(*sliceResult.getUsers().begin());
+  if (!yieldOp)
+    return nullptr;
+  for (Value operand : yieldOp->getOperands()) {
+    if (operand == sliceResult)
+      break;
+    ++resultNumber;
+  }
+
+  // The caller only checks that the parent is an scf.for after this call, so
+  // guard the result index here rather than indexing out of range.
+  Operation *parentOp = candidateSliceOp->getParentOp();
+  if (!parentOp || resultNumber >= parentOp->getNumResults())
+    return nullptr;
+  Value resultingValue = parentOp->getResult(resultNumber);
+
+  // Step 2. Take the first user of that result as the consumer.
+  // TODO(avarma): Address the case where the consumer op itself can return
+  // more than one result.
+  if (resultingValue.use_empty())
+    return nullptr;
+  Operation *untiledConsumer = *resultingValue.getUsers().begin();
+  for (Value operand : untiledConsumer->getOperands()) {
+    if (operand == resultingValue)
+      break;
+    ++operandNumber;
+  }
+  return untiledConsumer;
+}
+
+/// Returns true if `result` has exactly one use. If `result` is produced by a
+/// tensor.insert_slice, that single user must additionally be an scf.yield.
+/// Otherwise emits an LLVM_DEBUG message naming the violated condition and
+/// returns false.
+static bool checkAssumptionForFusingConsumer(Value result) {
+  if (!result.hasOneUse()) {
+    // Covers both zero uses and more than one use.
+    LLVM_DEBUG(llvm::dbgs() << "Expected exactly one use of the value\n");
+    return false;
+  }
+  // Only values defined by tensor.insert_slice are required to feed an
+  // scf.yield; values from other producers may have any single user.
+  Operation *onlyUser = *result.getUsers().begin();
+  if (result.getDefiningOp<tensor::InsertSliceOp>() &&
+      !isa<scf::YieldOp>(onlyUser)) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ if (!checkAssumptionForFusingConsumer(candidateSliceOp.getResult())) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "needs only scf.yield as its user");
+ }
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ unsigned resultNumber = 0;
+ Operation *consumerOp = getUntiledConsumerFromSliceDestSCFFor(
+ candidateSliceOp, operandNumber, resultNumber);
+ if (!consumerOp)
+ return failure();
+
+ // Check that the consumer results in exactly one value.
+ // TODO: Support fusion for consumers yielding more than one result.
+ if (consumerOp->getResults().size() != 1) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "only those consumers returning exactly one result is supported");
+ }
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ // Check containing op is "scf::ForOp".
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp) {
+ return rewriter.notifyMatchFailure(containingOp,
+ "containing op is not a scf.for");
+ }
+
+ // Check containingOp has exactly one use.
+ if (!checkAssumptionForFusingConsumer(forOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(forOp, "scf.for has more than 1 uses");
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer is not a TileableInterface");
+ }
+
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op should have destination style op interface");
+ }
+
+ // Check consumer is not using scf.for's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
----------------
ftynse wrote:
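Nit: it is no longer necessary to specify the number of inline (stack) elements as a template argument to `to_vector` (the `4` in `to_vector<4>`); plain `llvm::to_vector(...)` picks a suitable default.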
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list