[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri May 17 03:36:39 PDT 2024
================
@@ -1100,6 +1102,398 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks the assumptions that must hold for `result` (expected to be the
+/// result of a tensor.insert_slice) before a consumer can be fused through it:
+///   1. `result` has exactly one use.
+///   2. The owner of that use is an scf.yield.
+///   3. The op defining `result` and the scf.yield live in the same block.
+/// Returns failure() as soon as one assumption does not hold. A value that has
+/// no defining op (i.e. a block argument) is rejected as well.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  // getDefiningOp() is null for block arguments; bail out instead of
+  // dereferencing a null pointer in the block comparison below.
+  Operation *definingOp = result.getDefiningOp();
+  if (!definingOp) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected the candidate value to be an op result\n");
+    return failure();
+  }
+  if (definingOp->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the operand through which the single untiled consumer uses the
+/// scf.for result yielded by `candidateSliceOp` (a tensor.insert_slice), or
+/// nullptr when fusion is not applicable. nullptr is returned unless all of
+/// the following hold:
+///   1. The tensor.insert_slice has scf.yield as its only user.
+///   2. The parent op of the tensor.insert_slice is an scf.for.
+///   3. The corresponding scf.for result has exactly one use.
+///   4. The user of that result implements both TilingInterface and
+///      DestinationStyleOpInterface.
+///   5. That user and the scf.for are in the same block.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value yieldedSlice = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(yieldedSlice)))
+    return nullptr;
+
+  // The slice has a single use (checked above); the operand position of that
+  // use selects the loop result to look at.
+  unsigned resultIdx = yieldedSlice.getUses().begin()->getOperandNumber();
+
+  // Only scf.for is handled as the containing loop here.
+  auto forOp = dyn_cast<scf::ForOp>(candidateSliceOp->getParentOp());
+  if (!forOp)
+    return nullptr;
+
+  // The loop result must feed exactly one operand.
+  Value loopResult = forOp->getResult(resultIdx);
+  if (!llvm::hasSingleElement(loopResult.getUses()))
+    return nullptr;
+
+  OpOperand &consumerUse = *loopResult.getUses().begin();
+  Operation *consumer = consumerUse.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  bool hasRequiredInterfaces = isa<TilingInterface>(consumer) &&
+                               isa<DestinationStyleOpInterface>(consumer);
+  if (!hasRequiredInterfaces)
+    return nullptr;
+
+  // The consumer has to sit next to the loop, i.e. in the same block.
+  return forOp->getBlock() == consumer->getBlock() ? &consumerUse : nullptr;
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice. Returns the consumer's operand that uses
+/// the scf.forall result, or nullptr unless all of the following hold:
+///   1. The destination of the tensor.parallel_insert_slice is a block
+///      argument whose owning block belongs to an scf.forall.
+///   2. The scf.forall result tied to that block argument has exactly one use.
+///   3. The user of that result implements both TilingInterface and
+///      DestinationStyleOpInterface.
+///   4. That user and the scf.forall are in the same block.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output. The destination is not guaranteed
+  // to be a block argument, so use dyn_cast and report "no consumer" instead
+  // of asserting inside cast<>.
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg) {
+    return nullptr;
+  }
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return nullptr;
+  }
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    return nullptr;
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*uses.begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return nullptr;
+  }
+  // Step 5. The consumer must live in the same block as the scf.forall.
+  if (containingOp->getBlock() != consumerOp->getBlock()) {
+    return nullptr;
+  }
+  return &operand;
+}
+
+/// Succeeds when cloning `loopOp` right before `consumerOp` cannot break the
+/// def-use chain, which is taken to be the case if either:
+///   1. the loop yields exactly one result, or
+///   2. every user of the loop other than `consumerOp` is in the same block as
+///      `consumerOp` and comes after it in that block.
+/// Returns failure() otherwise.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // A single-result loop is accepted without looking at its users.
+  if (loopOp->getNumResults() == 1)
+    return success();
+
+  Block *consumerBlock = consumerOp->getBlock();
+  // The block comparison comes first so that isBeforeInBlock is only queried
+  // for users that share the consumer's block.
+  auto isAcceptableUser = [&](Operation *user) {
+    if (user == consumerOp)
+      return true;
+    return user->getBlock() == consumerBlock &&
+           consumerOp->isBeforeInBlock(user);
+  };
+  return success(llvm::all_of(loopOp->getUsers(), isAcceptableUser));
+}
+
+/// Dispatches to the tensor.insert_slice or tensor.parallel_insert_slice
+/// overload depending on the dynamic type of `sliceOp`. Returns nullptr for
+/// any other operation.
+static OpOperand *getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  return nullptr;
+}
+
+/// Replaces the scf.yield terminating the body of `newForOp` with one that
+/// additionally yields the results of the tiled consumer. For every result of
+/// `tilingResult.tiledOps[0]` a tensor.insert_slice is created that inserts
+/// the result into the block argument at position `1 + initSize + idx`, using
+/// `resultOffsets[idx]`, `resultSizes[idx]` and unit strides. The new
+/// terminator yields the old operands followed by these insert_slice results.
+/// Note: the `strides` parameter is not read; unit strides are built locally.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      SmallVector<OpFoldResult> &strides, unsigned initSize) {
+  Block *loopBody = newForOp.getBody();
+  auto staleYield = cast<scf::YieldOp>(loopBody->getTerminator());
+  Location loc = newForOp.getLoc();
+  MutableArrayRef<BlockArgument> bodyArgs = loopBody->getArguments();
+
+  // Start from the values that were already being yielded.
+  SmallVector<Value> yieldedValues(staleYield.getResults());
+
+  // New ops are created after the stale terminator, which is erased last.
+  rewriter.setInsertionPointAfter(staleYield);
+  Operation *tiledConsumer = tilingResult.tiledOps[0];
+  for (unsigned idx = 0, numResults = tiledConsumer->getNumResults();
+       idx < numResults; ++idx) {
+    SmallVector<OpFoldResult> unitStrides(resultOffsets[idx].size(),
+                                          rewriter.getIndexAttr(1));
+    // Destination is the block argument at `1 + initSize + idx`.
+    Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledConsumer->getResult(idx), bodyArgs[1 + initSize + idx],
+        resultOffsets[idx], resultSizes[idx], unitStrides);
+    yieldedValues.push_back(insertedSlice);
+  }
+  rewriter.create<scf::YieldOp>(loc, yieldedValues);
+  rewriter.eraseOp(staleYield);
+}
+
+/// Extends the scf.forall.in_parallel terminator of `newForallOp` with one
+/// tensor.parallel_insert_slice per result of `tilingResult.tiledOps[0]`. Each
+/// new op inserts the result into the block argument at position
+/// `rank + initSize + idx`, using `resultOffsets[idx]`, `resultSizes[idx]` and
+/// unit strides, and is created at the start of the terminator's body with the
+/// location of the terminator's first yielding op.
+/// Note: the `strides` parameter is not read; unit strides are built locally.
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp,
+    TilingResult tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    SmallVector<OpFoldResult> &strides, unsigned initSize, unsigned rank) {
+  scf::InParallelOp inParallelOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(inParallelOp.getBody());
+
+  // Reuse the location of the first op already present in the terminator.
+  Location insertLoc = inParallelOp.getYieldingOps().begin()->getLoc();
+  MutableArrayRef<BlockArgument> bodyArgs =
+      newForallOp.getBody()->getArguments();
+
+  Operation *tiledConsumer = tilingResult.tiledOps[0];
+  for (unsigned idx = 0, numResults = tiledConsumer->getNumResults();
+       idx < numResults; ++idx) {
+    SmallVector<OpFoldResult> unitStrides(resultOffsets[idx].size(),
+                                          rewriter.getIndexAttr(1));
+    // Destination is the block argument at `rank + initSize + idx`.
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        insertLoc, tiledConsumer->getResult(idx),
+        bodyArgs[rank + initSize + idx], resultOffsets[idx], resultSizes[idx],
+        unitStrides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ candidateSliceOp))
+ return failure();
+
+ bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice/parallel_insert_slice.
+ OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+ if (!consumerOpOperand) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+ Operation *oldLoopOp = nullptr;
+ SmallVector<Value> newOuts;
+ Block *oldLoopBody = nullptr;
+ unsigned initSize = 0;
+ unsigned rank = 1;
+ if (isInsertSliceOp) {
+ auto forOp = candidateSliceOp->template getParentOfType<scf::ForOp>();
----------------
ftynse wrote:
```suggestion
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
```
Nit: I believe the `template` keyword is no longer necessary, here and below.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list