[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Abhishek Varma
llvmlistbot at llvm.org
Mon Apr 22 03:35:33 PDT 2024
================
@@ -1100,6 +1102,475 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the passed value has only one user.
+/// In case the defining operation is a tensor.insert_slice, it checks if the
+/// user is scf.yield.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+ Value::use_range uses = result.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+ return failure();
+ }
+ OpOperand &operandUse = (*uses.begin());
+ Operation *userOp = operandUse.getOwner();
+ if (!isa<scf::YieldOp>(userOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Expected scf.yield to be the only user, but got -> "
+ << (*userOp));
+ return failure();
+ }
+ return success();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1. tensor.insert_slice has scf.yield as its only user.
+/// 2. scf.for's correspon
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+ Value sliceResult = candidateSliceOp.getResult();
+ if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult()))) {
+ return failure();
+ }
+ // Step 1. Fetch the corresponding output.
+ OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+ unsigned resultNumber = yieldOpOperand.getOperandNumber();
+ // Check containing op is "scf::ForOp".
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp) {
+ return failure();
+ }
+ Value resultingValue = forOp->getResult(resultNumber);
+
+ // Check resultingValue has exactly one use.
+ if (!llvm::hasSingleElement(resultingValue.getUses())) {
+ return failure();
+ }
+
+ // Step 2. Get uses.
+ OpOperand &operand = (*resultingValue.getUses().begin());
+ return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the source.
+ FailureOr<OpOperand *> consumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (failed(consumerOpOperand)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = (*consumerOpOperand)->getOwner();
+ unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ auto forOp = static_cast<scf::ForOp>(containingOp);
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer is not a TileableInterface");
+ }
+
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op should have destination style op interface");
+ }
+
+ // Check consumer is not using scf.for's output as init.
+ SmallVector<Value> dpsInits =
----------------
Abhishek-Varma wrote:
You're correct. However, this specific check ensures that the single use is not one of the consumer's `outs(...)` (init) operands, so it cannot be dropped.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list