[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 19 00:09:22 PDT 2024
================
@@ -1100,6 +1101,428 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Fetches the first untiled consumer of the loop result that corresponds to
+/// `source`, the source operand of a tensor.insert_slice op. The op enclosing
+/// the insert_slice (expected to be the scf.for) is taken as the loop, and its
+/// result is looked up with the operand number of `source`.
+///
+/// Only the first user of that result is inspected. `operandNumber` is
+/// incremented once for every operand of that user that precedes the operand
+/// using the loop result, so callers passing 0 get back the operand index.
+///
+/// Returns the consumer (nullptr if the loop result has no users) together
+/// with an optional destination iter_arg, which is currently never populated.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output of the enclosing loop.
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Get the first user, if any.
+  std::optional<OpOperand *> destinationIterArg;
+  // Must start out null: callers test the returned pointer to detect the
+  // "no consumer" case.
+  Operation *untiledConsumer = nullptr;
+  if (!resultingValue.use_empty()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    // more than one result.
+    Operation *user = *resultingValue.user_begin();
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+/// Returns true when `op` satisfies the preconditions for consumer fusion:
+/// its first result has exactly one use and, if `op` is a
+/// tensor.insert_slice, that single user is an scf.yield.
+static bool checkAssumptionForFusingConsumer(Operation *op) {
+  Value firstResult = op->getResult(0);
+  if (!firstResult.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return false;
+  }
+
+  // Exactly one use exists here, so the first user is the only one.
+  Operation *soleUser = *firstResult.user_begin();
+  const bool mustFeedYield = isa<tensor::InsertSliceOp>(op);
+  if (mustFeedYield && !isa<scf::YieldOp>(soleUser)) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Fuses the untiled consumer of the scf.for enclosing `candidateSliceOp` into the loop by computing the
+/// consumer's tile in-place inside a newly created scf.for; fails if a precondition checked below does not hold.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
+ // NOTE: this assumes the scf.yield yields exactly one value; the slice must have scf.yield as its only user.
+ if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "needs only scf.yield as its user");
+ }
+ // 1. Get the untiled consumer of the loop result and the index of the consumer operand that uses it.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+ candidateSliceOp->getOpOperand(0), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ // Check that the consumer implements TilingInterface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer is not a TileableInterface");
+ }
+
+ // Check that the op enclosing the slice is an scf::ForOp.
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp) {
+ return rewriter.notifyMatchFailure(containingOp,
+ "containing op is not a scf.for");
+ }
+
+ // Check that the loop has exactly one result (asserted) and that this result has exactly one use.
+ assert(forOp.getResults().size() == 1 &&
+ "expect exactly one result of the containing op");
+ if (!checkAssumptionForFusingConsumer(forOp)) {
+ return rewriter.notifyMatchFailure(forOp, "scf.for has more than 1 uses");
+ }
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer's operand use more than one containingOp's result");
+ }
+ }
+ }
+
+ // TODO: The consumer's results must be initialized before the scf.for; for
+ // now DestinationStyleOpInterface is used to get the result shape from the
+ // inits. Add support for other ops, e.g. those with InferTypeOpInterface.
+ // Check that the consumer implements DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op should have destination style op interface");
+ }
+
+ // Check that the consumer does not use the scf.for result as one of its DPS inits.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, forOp.getResult(0))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.for as init is not supported");
+ }
+
+ Location loc = forOp.getLoc();
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+ // Check that every stride of the insert_slice is the constant 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+
+ SmallVector<Value> newOuts(forOp.getInits());
+ newOuts.append(dpsInits);
+
+ // Create a new scf.for, right before the consumer, whose inits are the old inits followed by the consumer's DPS inits.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+ forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ // Move the old loop body into the new op, mapping the old block arguments onto the leading arguments of the new body.
+ Block *loopBody = forOp.getBody();
+ Block *newLoopBody = newforOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // Clone the consumer after the insert_slice, passing the block arguments appended for the consumer's inits as destinations.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ SmallVector<Value> newForOpBlockArgsForConsumerDest;
+ for (unsigned i = loopBody->getNumArguments(),
+ n = newLoopBody->getArguments().size();
+ i < n; i++) {
+ newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+ }
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // Replace all uses of the old scf.for result with the insert_slice result.
+ rewriter.replaceAllUsesWith(forOp.getResult(0), candidateSliceOp.getResult());
+
+ // Generate the tiled implementation of the cloned consumer, inserted before the insert_slice.
+ rewriter.setInsertionPoint(candidateSliceOp);
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter, candidateSliceOp,
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(tileableConsumer,
+ "failed to tile consumer op: ");
+ }
+
+ // Clone the insert_slice with the first tiled value as its source and the first appended block argument as its destination.
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+ tensor::InsertSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+ auto bbArgs = newforOp.getBody()->getArguments();
+ clonedCandidateSliceOp->getOpOperands()[1].set(
+ bbArgs[1 + forOp.getInits().size() + 0]);
+
+ rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
+ rewriter.eraseOp(clonedConsumerOp);
+
+ // Rebuild the terminator so that it additionally yields the cloned insert_slice result.
+ scf::YieldOp oldTerminatorOp =
+ static_cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+
+ SmallVector<Value> newYieldOperands;
+ for (Value val : oldTerminatorOp.getResults()) {
+ if (val == candidateSliceOp.getSource()) {
+ newYieldOperands.push_back(candidateSliceOp.getResult());
+ } else {
+ newYieldOperands.push_back(val);
+ }
+ }
+ newYieldOperands.push_back(clonedCandidateSliceOp.getResult());
+ rewriter.setInsertionPointAfter(oldTerminatorOp);
+ rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+ rewriter.eraseOp(oldTerminatorOp);
+
+ // Replace the results of the old scf.for and of the untiled consumer with the corresponding results of the new scf.for.
+ for (auto result : llvm::enumerate(forOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforOp->getResult(result.index()));
+ }
+
+ for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ consumerResult.value(),
+ newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+ }
+
+ // Erase the old scf.for and the untiled consumer. NOTE(review): `consumerOp` is returned below after being erased - confirm callers never dereference it.
+ rewriter.eraseOp(forOp);
+ rewriter.eraseOp(consumerOp);
+
+ return scf::SCFFuseConsumerOfSliceResult{
+ consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
+}
+
+/// Fetches the first untiled consumer of the loop result reached from `dest`,
+/// the destination operand of a tensor.parallel_insert_slice op. `dest` is
+/// expected to be a block argument; the op owning that argument's block
+/// (expected to be the scf.forall) is taken as the loop and its result #0 is
+/// used.
+///
+/// Only the first user of that result is inspected. `operandNumber` is
+/// incremented once for every operand of that user that precedes the operand
+/// using the loop result, so callers passing 0 get back the operand index.
+///
+/// Returns the consumer (nullptr if `dest` is not a block argument or the
+/// loop result has no users) together with an optional destination iter_arg,
+/// which is currently never populated.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *dest,
+                                         unsigned &operandNumber) {
+  std::optional<OpOperand *> destinationIterArg;
+  // Must start out null: callers test the returned pointer to detect the
+  // "no consumer" case.
+  Operation *untiledConsumer = nullptr;
+
+  // Step 1. Fetch the corresponding output of the enclosing loop.
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = dyn_cast<BlockArgument>(dest->get());
+  // dyn_cast yields a null value when the destination is not a block
+  // argument; report "no consumer" instead of dereferencing it.
+  if (!iterArg)
+    return {untiledConsumer, destinationIterArg};
+  Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+  // Step 2. Get the first user, if any.
+  if (!resultingValue.use_empty()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    // more than one result.
+    Operation *user = *resultingValue.user_begin();
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the dest.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] =
+ getUntiledConsumerFromSliceDestSCFForall(
+ &candidateSliceOp.getDestMutable(), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer is not a TileableInterface");
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ return rewriter.notifyMatchFailure(containingOp,
+ "containing op is not a scf.forall");
+ }
+
+ // Check that the consumer doesn't use more than one result of containingOp.
+ // Check containingOp has exactly one use.
+ assert(forallOp.getResults().size() == 1 &&
+ "expect exactly one result of the containing op");
+ if (!checkAssumptionForFusingConsumer(forallOp)) {
+ return rewriter.notifyMatchFailure(forallOp,
+ "scf.forall has more than 1 uses");
+ }
+
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ return rewriter.notifyMatchFailure(
+ consumerOp, "consumer op should have destination style op interface");
+ }
+
+ // Check consumer doesn't use scf.forall's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, forallOp.getResult(0))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.forall as init is not supported");
+ }
+
+ SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+ // Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+
+ Location loc = forallOp.getLoc();
+ // Create new scf.forall op.
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+ // Move the loop body to the new op.
+ rewriter.eraseOp(newforallOp.getTerminator());
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // Clone the consumer after the insert_slice.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ SmallVector<Value> newForOpBlockArgsForConsumerDest;
+ for (unsigned i = loopBody->getNumArguments(),
+ n = newLoopBody->getArguments().size();
+ i < n; i++) {
+ newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+ }
+ auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+ // Replace scf.forall result's use in the consumer with parallel_insert_slice
+ // source.
+ rewriter.replaceAllUsesWith(forallOp.getResult(0),
+ candidateSliceOp.getSource());
+
+ // Generate the tiled implementation of the consumer of the source.
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+ FailureOr<TilingResult> tileAndFuseResult =
+ tensor::replaceInsertSliceWithTiledConsumer(
+ rewriter, candidateSliceOp,
+ clonedConsumerOp->getOpOperand(operandNumber));
+ if (failed(tileAndFuseResult)) {
+ return rewriter.notifyMatchFailure(tileableConsumer,
+ "failed to tile consumer op: ");
+ }
+
+ // Update the source of the candidateSlice to be the cloned consumer.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+ tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ clonedCandidateSliceOp->getOpOperands()[1].set(
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + 0]);
+
+ rewriter.eraseOp(clonedConsumerOp);
+
+ // Replace the result of scf.forall and consumer op.
+ for (auto result : llvm::enumerate(forallOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforallOp->getResult(result.index()));
+ }
+
+ for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ consumerResult.value(),
+ newforallOp->getResult(forallOp.getOutputs().size() +
+ consumerResult.index()));
+ }
+
+ // Need to erase the old scf.forall and consumer.
+ rewriter.eraseOp(forallOp);
----------------
MaheshRavishankar wrote:
Ok, that makes sense... we probably need a DCE on the `scf.forall` to remove the unused original result.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list