[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Abhishek Varma
llvmlistbot at llvm.org
Wed May 22 23:24:53 PDT 2024
================
@@ -1100,6 +1102,413 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerOfSlice implementation.
+//===----------------------------------------------------------------------===//
+
+/// Verifies the preconditions for fusing a consumer through
+/// `candidateSliceOp`: the result of the tensor.insert_slice must have exactly
+/// one use, the owner of that use must be an scf.yield, and that scf.yield
+/// must live in the same block as the tensor.insert_slice. Returns failure()
+/// (after emitting a debug message) as soon as one of these does not hold.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  // Anything other than exactly one use (zero or several) is rejected.
+  if (!sliceResult.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+
+  // The single remaining use must be owned by an scf.yield.
+  Operation *soleUser = sliceResult.getUses().begin()->getOwner();
+  if (!isa<scf::YieldOp>(soleUser)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*soleUser));
+    return failure();
+  }
+
+  // `candidateSliceOp` is the defining op of `sliceResult`, so its block is
+  // the one the scf.yield has to share.
+  if (candidateSliceOp->getBlock() != soleUser->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the single OpOperand through which `val` is used, provided that:
+///   - `val` has exactly one use,
+///   - the owner of that use implements both `TilingInterface` and
+///     `DestinationStyleOpInterface`, and
+///   - the owner lives in `containingOpBlock`.
+/// Returns failure() in every other case.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // The value must be consumed exactly once.
+  if (!val.hasOneUse())
+    return failure();
+
+  OpOperand &soleUse = *val.getUses().begin();
+  Operation *user = soleUse.getOwner();
+
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  const bool isTilableDpsOp =
+      isa<TilingInterface>(user) && isa<DestinationStyleOpInterface>(user);
+  if (!isTilableDpsOp || user->getBlock() != containingOpBlock)
+    return failure();
+
+  return &soleUse;
+}
+
+/// Finds the untiled consumer of the scf.for result that is produced by
+/// yielding `candidateSliceOp`. The lookup only succeeds when:
+///   1. the tensor.insert_slice has an scf.yield as its only user,
+///   2. the parent op of the tensor.insert_slice is an scf.for, and
+///   3. the matching scf.for result has exactly one use whose owner passes the
+///      checks of `getConsumerFromUses`.
+/// On success the returned OpOperand is the consumer's use of the loop result.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+
+  // The op directly enclosing the slice has to be an scf.for.
+  Operation *parentOp = candidateSliceOp->getParentOp();
+  auto enclosingFor = dyn_cast<scf::ForOp>(parentOp);
+  if (!enclosingFor)
+    return failure();
+
+  // The position of the slice in the scf.yield operand list selects the loop
+  // result that carries the slice out of the loop.
+  OpOperand &yieldUse = *candidateSliceOp.getResult().getUses().begin();
+  Value loopResult = enclosingFor->getResult(yieldUse.getOperandNumber());
+
+  return getConsumerFromUses(loopResult, parentOp->getBlock());
+}
+
+/// Finds the untiled consumer of the scf.forall result that is written by
+/// `candidateSliceOp` (a tensor.parallel_insert_slice). The destination of the
+/// slice must be a block argument owned by the scf.forall that is the
+/// grandparent op of the slice; the result tied to that argument is then
+/// handed to `getConsumerFromUses`.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // The destination has to be a block argument.
+  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
+  if (!destArg)
+    return failure();
+
+  // The op owning that block argument must be the grandparent of the slice.
+  Operation *loopOp = destArg.getOwner()->getParentOp();
+  if (candidateSliceOp->getParentOp()->getParentOp() != loopOp)
+    return failure();
+
+  // Only scf.forall is handled here.
+  auto forallOp = dyn_cast<scf::ForallOp>(loopOp);
+  if (!forallOp)
+    return failure();
+
+  // Map the block argument to its tied operand, and from there to the tied
+  // loop result.
+  OpOperand *tiedOperand = forallOp.getTiedOpOperand(destArg);
+  Value loopResult = forallOp.getTiedOpResult(tiedOperand);
+
+  return getConsumerFromUses(loopResult, loopOp->getBlock());
+}
+
+/// Checks that `loopOp` can be re-created right before `consumerOp` without
+/// breaking any def-use chain. This holds when either:
+///   1. the loop has exactly one result, or
+///   2. every user of the loop other than `consumerOp` is in the same block
+///      as `consumerOp` and comes after it in that block.
+/// The new loop is created immediately before the consumer op, so a user
+/// placed earlier (or in another block) could end up using a value that is
+/// not yet defined.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // A loop with a single result is always accepted.
+  if (loopOp->getNumResults() == 1)
+    return success();
+
+  Block *consumerBlock = consumerOp->getBlock();
+  // The block comparison comes first so that `isBeforeInBlock` is only
+  // queried for users that share the consumer's block.
+  bool allUsersFollowConsumer =
+      llvm::all_of(loopOp->getUsers(), [&](Operation *user) {
+        return user == consumerOp || (user->getBlock() == consumerBlock &&
+                                      consumerOp->isBeforeInBlock(user));
+      });
+  return success(allUsersFollowConsumer);
+}
+
+/// Dispatches to the tensor.insert_slice or tensor.parallel_insert_slice
+/// overload of `getUntiledConsumerFromSlice` depending on the kind of
+/// `sliceOp`. Any other operation yields failure().
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  return failure();
+}
+
+/// Rebuilds the scf.yield terminator of `newForOp` after a consumer has been
+/// fused into the loop body. The new terminator yields every value the old
+/// terminator yielded, followed by one tensor.insert_slice per result of the
+/// tiled consumer (`tilingResult.tiledOps[0]`).
+///
+/// For the i-th tiled result, the slice is inserted into `bbArgs[i]` at
+/// `resultOffsets[i]` with sizes `resultSizes[i]` and unit strides. The tiled
+/// results, `bbArgs`, `resultOffsets` and `resultSizes` are iterated with
+/// llvm::zip_equal and are therefore expected to have the same length.
+/// The old terminator is erased once the new one has been created.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  Operation *tiledConsumerOp = tilingResult.tiledOps[0];
+
+  // The values forwarded by scf.yield are its operands; the operation itself
+  // defines no results. Size the new operand list from the operand count so
+  // that the reservation below actually covers the old yielded values.
+  unsigned totalOldResults = oldTerminatorOp->getNumOperands();
+  unsigned totalTiledResults = tiledConsumerOp->getNumResults();
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(totalOldResults + totalTiledResults);
+  llvm::append_range(newYieldOperands, oldTerminatorOp.getResults());
+
+  // New ops are created after the old terminator, which is erased at the end.
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledConsumerOp->getResults(), bbArgs, resultOffsets,
+                       resultSizes)) {
+    // One unit stride per offset dimension.
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// Extends the scf.forall.in_parallel terminator of `newForallOp` after a
+/// consumer has been fused into the loop body: for the i-th entry of
+/// `tiledResults` a tensor.parallel_insert_slice is created at the start of
+/// the terminator body, inserting the value into `bbArgs[i]` at
+/// `resultOffsets[i]` with sizes `resultSizes[i]` and unit strides. The new
+/// ops reuse the location of the first op already present in the terminator.
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp inParallelOp = newForallOp.getTerminator();
+  // Location of the first yielding op that already exists in the terminator.
+  Location firstYieldOpLoc = inParallelOp.getYieldingOps().begin()->getLoc();
+
+  rewriter.setInsertionPointToStart(inParallelOp.getBody());
+  for (auto [tiledValue, destArg, offsets, sizes] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    // One unit stride per offset dimension.
+    SmallVector<OpFoldResult> unitStrides(offsets.size(),
+                                          rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, tiledValue, destArg, offsets, sizes, unitStrides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ candidateSliceOp))
+ return failure();
+
+ bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice/parallel_insert_slice.
+ FailureOr<OpOperand *> maybeConsumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (failed(maybeConsumerOpOperand)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber = 0;
+ if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+ resultNumber = producerResult.getResultNumber();
+ } else {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+ }
+
+ Operation *oldLoopOp = nullptr;
+ SmallVector<Value> newOuts;
+ Block *oldLoopBody = nullptr;
+ unsigned initSize = 0;
+ unsigned rank = 1;
+ if (isInsertSliceOp) {
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+ oldLoopOp = forOp;
+ llvm::append_range(newOuts, forOp.getInits());
+ oldLoopBody = forOp.getBody();
+ initSize = forOp.getInits().size();
+ } else {
+ auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+ oldLoopOp = forallOp;
+ llvm::append_range(newOuts, forallOp.getOutputs());
+ oldLoopBody = forallOp.getBody();
+ initSize = forallOp.getOutputs().size();
+ rank = forallOp.getRank();
+ }
+
+ if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ return rewriter.notifyMatchFailure(
+ oldLoopOp, "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 2. Check consumer is not using scf loop's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
+ newOuts.append(dpsInits);
+
+ Location loc = oldLoopOp->getLoc();
+
+ // 3. Create new scf loop op.
+ rewriter.setInsertionPoint(consumerOp);
+ Operation *newLoopOp = nullptr;
+ Block *newLoopBody = nullptr;
+ if (isInsertSliceOp) {
+ auto forOp = cast<scf::ForOp>(oldLoopOp);
+ auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ newLoopOp = newForOp;
+ newLoopBody = newForOp.getBody();
+ } else {
+ auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ auto newForallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ newLoopOp = newForallOp;
+ rewriter.eraseOp(newForallOp.getTerminator());
+ newLoopBody = newForallOp.getBody();
+ }
+
+ // 4. Move the loop body to the new op.
+ unsigned oldNumArguments = oldLoopBody->getNumArguments();
+ rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(oldNumArguments));
+
+ // 5. Set insertion point before terminator op of the loop and create a new
+ // tensor.insert_slice. In the scf.for case this is a clone of the
+ // candidateSliceOp whereas in the scf.forall case this is created from the
+ // operands of tensor.parallel_insert_slice.
+ tensor::InsertSliceOp clonedInsertSliceOp;
+ if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ rewriter.setInsertionPoint(newForallOp.getTerminator());
+ clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ } else {
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
+ rewriter.setInsertionPoint(newForOp.getBody()->getTerminator());
+ clonedInsertSliceOp =
+ cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+ }
+
+ // 6.a. Clone consumer op.
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(oldNumArguments);
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 6.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(clonedInsertSliceOp.getResult());
+ });
+
+ // 7 - Perform tiling of the cloned consumer and replace the operand at
+ // `operandNumber` with the source of the cloned tensor.insert_slice op.
+ auto ossSliceOp =
+ cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return failure();
+ }
+ rewriter.replaceAllUsesWith(
+ tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
+
+ // 8 - Extract offset/sizes/strides required to create the
+ // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+ // 9. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+
+ // 10. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp, "can't get iter domain position from input position");
+ }
+
+ // 11. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+ auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
+ if (isInsertSliceOp) {
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
+ fixTerminatorSCFYield(
+ rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+ newForOp.getBody()->getArguments().drop_front(1 + initSize));
+ } else {
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ fixTerminatorSCFInParallel(
+ rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
+ arrayRefOffsets, arrayRefSizes,
+ newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+ }
+
+ // 12. Replace the result of scf loop and consumer op with new loop's results.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(consumerOp->getResults(),
+ newLoopOp->getResults().drop_front(initSize))) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+
+ // 13. Need to erase the old scf loop and the cloned consumer op.
+ rewriter.eraseOp(oldLoopOp);
----------------
Abhishek-Varma wrote:
No, I can't. It has uses because of the following :-
```
%clone_insert_slice = cloned tensor.insert_slice :
%source<32> into %dest<64> | OFFSET | STRIDES | SIZES
%tiled_operand = tensor.extract_slice:
%clone_insert_slice<64> to <32> | OFFSET | STRIDES | SIZES
```
So, in order to delete `%clone_insert_slice`, `%tiled_operand` needs to be deleted as well.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list