[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Quinn Dawkins
llvmlistbot at llvm.org
Wed May 22 10:53:48 PDT 2024
================
@@ -1100,6 +1103,408 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerOfSlice implementation.
+//===----------------------------------------------------------------------===//
+
+/// Verifies the preconditions under which the consumer of the loop result
+/// produced through `candidateSliceOp` may be fused:
+///   1. the result of the tensor.insert_slice has exactly one use,
+///   2. the owner of that use is an scf.yield, and
+///   3. the scf.yield lives in the same block as the tensor.insert_slice.
+/// Returns failure (after emitting a debug message) if any of these does not
+/// hold, and success otherwise.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value result = candidateSliceOp.getResult();
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = *result.use_begin();
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    // Terminate the message with a newline so that it does not run into
+    // whatever is printed to the debug stream next.
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  // `result` is defined by `candidateSliceOp`, so compare against its block
+  // directly.
+  if (candidateSliceOp->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the single use of `val`, provided that the operation owning that
+/// use implements both `TilingInterface` and `DestinationStyleOpInterface`
+/// and is located in `containingOpBlock`. Fails if `val` does not have exactly
+/// one use or if the user does not satisfy these requirements.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // The value must be consumed exactly once.
+  if (!val.hasOneUse())
+    return failure();
+  OpOperand &soleUse = *val.use_begin();
+  Operation *user = soleUse.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  const bool isTilableDpsOp =
+      isa<TilingInterface>(user) && isa<DestinationStyleOpInterface>(user);
+  if (!isTilableDpsOp)
+    return failure();
+  // The user has to sit in the block the caller expects.
+  if (user->getBlock() != containingOpBlock)
+    return failure();
+  return &soleUse;
+}
+
+/// Looks up the untiled consumer of the scf.for result that is yielded
+/// through `candidateSliceOp`. The lookup relies on these assumptions:
+///   1. The tensor.insert_slice has an scf.yield as its only user.
+///   2. The corresponding scf.for result has exactly one use.
+/// Fails if the parent op is not an scf.for or if an assumption is violated.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+
+  // The op directly enclosing the slice must be an scf.for.
+  Operation *parentOp = candidateSliceOp->getParentOp();
+  auto enclosingFor = dyn_cast<scf::ForOp>(parentOp);
+  if (!enclosingFor)
+    return failure();
+
+  // The position of the slice among the scf.yield operands identifies the
+  // loop result that carries it out of the loop.
+  OpOperand &yieldUse = *candidateSliceOp.getResult().use_begin();
+  Value loopResult = enclosingFor->getResult(yieldUse.getOperandNumber());
+
+  return getConsumerFromUses(loopResult, parentOp->getBlock());
+}
+
+/// Looks up the only untiled consumer of the scf.forall result that is
+/// yielded through `candidateSliceOp`. Fails if the destination of the
+/// tensor.parallel_insert_slice is not a block argument of the scf.forall
+/// enclosing it.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // The destination has to be a block argument ...
+  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
+  if (!destArg)
+    return failure();
+
+  // ... owned by the op two levels above the slice (the slice's parent's
+  // parent).
+  Operation *argOwnerOp = destArg.getOwner()->getParentOp();
+  Operation *enclosingLoopOp = candidateSliceOp->getParentOp()->getParentOp();
+  if (argOwnerOp != enclosingLoopOp)
+    return failure();
+
+  // That op must be an scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(argOwnerOp);
+  if (!forallOp)
+    return failure();
+
+  // Map the block argument to the loop result tied to it.
+  Value loopResult =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(destArg));
+
+  return getConsumerFromUses(loopResult, argOwnerOp->getBlock());
+}
+
+/// Succeeds if `loopOp` satisfies one of the following:
+///   1. It yields exactly one result.
+///   2. Every user of the loop other than `consumerOp` is in the same block
+///      as `consumerOp` and comes after it.
+/// The loop op is currently cloned right before the consumer op, so these
+/// conditions guard against forming an invalid def-use chain.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // A loop with a single result is always accepted.
+  if (loopOp->getNumResults() == 1)
+    return success();
+
+  // Otherwise every other user must follow the consumer within its block.
+  Block *consumerBlock = consumerOp->getBlock();
+  auto isAcceptableUser = [&](Operation *user) {
+    if (user == consumerOp)
+      return true;
+    return user->getBlock() == consumerBlock &&
+           consumerOp->isBeforeInBlock(user);
+  };
+  return success(llvm::all_of(loopOp->getUsers(), isAcceptableUser));
+}
+
+/// Dispatches to the tensor.insert_slice or tensor.parallel_insert_slice
+/// overload to fetch the untiled consumer for `sliceOp`. Fails for any other
+/// kind of operation.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  return failure();
+}
+
+/// After fusing the consumer into an scf.for, rebuilds the loop's scf.yield
+/// so that it additionally yields the values produced by the tiled consumer.
+///
+/// For every result of the first tiled op in `tilingResult`, a
+/// tensor.insert_slice is created that inserts the result into the matching
+/// entry of `bbArgs` at the matching `resultOffsets` / `resultSizes` with unit
+/// strides. The new scf.yield forwards the operands of the old terminator
+/// followed by these insert_slice results; the old terminator is then erased.
+/// The results of the tiled op, `bbArgs`, `resultOffsets` and `resultSizes`
+/// must all have the same length (enforced by `llvm::zip_equal`).
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  Operation *tiledConsumerOp = tilingResult.tiledOps[0];
+  // A terminator defines no results, so the number of previously yielded
+  // values is its operand count.
+  unsigned totalOldYieldedValues = oldTerminatorOp->getNumOperands();
+  unsigned totalTiledResults = tiledConsumerOp->getNumResults();
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(totalOldYieldedValues + totalTiledResults);
+  llvm::append_range(newYieldOperands, oldTerminatorOp->getOperands());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledConsumerOp->getResults(), bbArgs, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  // Create the replacement terminator before dropping the old one.
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// After fusing the consumer into an scf.forall, yields each value produced
+/// by the tiled consumer from within the scf.forall.in_parallel region.
+///
+/// For every entry of `tiledResults`, a tensor.parallel_insert_slice is
+/// created at the start of the terminator's body, inserting the value into
+/// the matching entry of `bbArgs` at the matching `resultOffsets` /
+/// `resultSizes` with unit strides. All four ranges must have the same length
+/// (enforced by `llvm::zip_equal`).
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp inParallelOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(inParallelOp.getBody());
+  // The new ops reuse the location of the first existing yielding op.
+  Location insertLoc = inParallelOp.getYieldingOps().begin()->getLoc();
+  for (auto [tiledValue, destArg, offsets, sizes] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> unitStrides(offsets.size(),
+                                          rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        insertLoc, tiledValue, destArg, offsets, sizes, unitStrides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+ candidateSliceOp))
+ return failure();
+
+ bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice/parallel_insert_slice.
+ FailureOr<OpOperand *> maybeConsumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (failed(maybeConsumerOpOperand)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
+ Operation *consumerOp = consumerOpOperand->getOwner();
+ unsigned operandNumber = consumerOpOperand->getOperandNumber();
+ unsigned resultNumber = 0;
+ if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+ resultNumber = producerResult.getResultNumber();
+ } else {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+ }
+
+ Operation *oldLoopOp = nullptr;
+ SmallVector<Value> newOuts;
+ Block *oldLoopBody = nullptr;
+ unsigned initSize = 0;
+ unsigned rank = 1;
+ if (isInsertSliceOp) {
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+ oldLoopOp = forOp;
+ llvm::append_range(newOuts, forOp.getInits());
+ oldLoopBody = forOp.getBody();
+ initSize = forOp.getInits().size();
+ } else {
+ auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+ oldLoopOp = forallOp;
+ llvm::append_range(newOuts, forallOp.getOutputs());
+ oldLoopBody = forallOp.getBody();
+ initSize = forallOp.getOutputs().size();
+ rank = forallOp.getRank();
+ }
+
+ if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ return rewriter.notifyMatchFailure(
+ oldLoopOp, "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 2. Check consumer is not using scf loop's output as init.
+ auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
+ newOuts.append(dpsInits);
+
+ Location loc = oldLoopOp->getLoc();
+
+ // 3. Create new scf loop op.
+ rewriter.setInsertionPoint(consumerOp);
+ Operation *newLoopOp = nullptr;
+ Block *newLoopBody = nullptr;
+ if (isInsertSliceOp) {
+ auto forOp = cast<scf::ForOp>(oldLoopOp);
+ auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ newLoopOp = newForOp;
+ newLoopBody = newForOp.getBody();
+ } else {
+ auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ auto newForallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ newLoopOp = newForallOp;
+ rewriter.eraseOp(newForallOp.getTerminator());
+ newLoopBody = newForallOp.getBody();
+ }
+
+ // 4. Move the loop body to the new op.
+ unsigned oldNumArguments = oldLoopBody->getNumArguments();
+ rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(oldNumArguments));
+
+ // 5.a. Clone consumer after the cloned
+ // tensor.insert_slice/parallel_insert_slice op.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(oldNumArguments);
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // 5.b. Replace all uses of the loop result with the result of the cloned
+ // tensor.insert_slice/parallel_insert_slice.
+ OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+ operandToReplace.set(sliceOp.getResult());
+ } else if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ operandToReplace.set(sliceOp.getSource());
----------------
qedawkins wrote:
OK, I'll defer to you then, but I think it would be good to document this implicit dependence on `extract_slice` in the SCF API.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list