[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Apr 22 04:57:18 PDT 2024
================
@@ -1100,6 +1102,459 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the passed value (the result of a
+/// candidate tensor.insert_slice) has exactly one use, and that this single
+/// use is an scf.yield. Returns failure otherwise.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  // hasSingleElement rejects both zero and multiple uses.
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected exactly one use of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = *uses.begin();
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << *userOp << "\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetch the untiled consumer of an scf.for result that is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions :-
+/// 1. tensor.insert_slice has scf.yield as its only user, and that scf.yield
+///    is the terminator of the scf.for that directly encloses the slice.
+/// 2. scf.for's corresponding result has only one use.
+/// The consumer must implement both TilingInterface and
+/// DestinationStyleOpInterface. Returns the consumer operand that uses the
+/// scf.for result, or failure if any of these checks does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(sliceResult))) {
+    return failure();
+  }
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = *sliceResult.getUses().begin();
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for, and that the scf.yield using the
+  // slice is that loop's own terminator. A yield belonging to a nested region
+  // op (e.g. an scf.if in the loop body) would make `resultNumber` refer to
+  // that op's results rather than the loop's.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast_or_null<scf::ForOp>(containingOp);
+  if (!forOp ||
+      yieldOpOperand.getOwner()->getParentOp() != forOp.getOperation()) {
+    return failure();
+  }
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Step 3. Check resulting value of scf.for has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses())) {
+    return failure();
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = *resultingValue.getUses().begin();
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  // DestinationStyleOpInterface to get result shape from init for now.
+  // Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return failure();
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+///
+/// The scf.for enclosing `candidateSliceOp` is replaced by a new scf.for that
+/// carries additional iter_args initialized from the consumer's DPS inits.
+/// Inside the loop, a tile of the consumer is computed from the inserted
+/// slice and inserted into those extra iter_args. Uses of the consumer's
+/// results are redirected to the new loop's extra results. The original
+/// consumer op is not erased here; it is returned in the result.
+///
+/// Preconditions that can be checked on the original IR (consumer discovery,
+/// init usage, unit strides) are checked before any IR is modified, so those
+/// match failures leave the IR untouched.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of scf.for for the result yielded by
+  //    tensor.insert_slice.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  // getUntiledConsumerFromSlice only succeeds when the slice's immediate
+  // parent is an scf.for and the consumer implements
+  // DestinationStyleOpInterface, so these checked casts cannot fail.
+  auto forOp = cast<scf::ForOp>(candidateSliceOp->getParentOp());
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 2. Check the consumer does not use any result of the scf.for as init.
+  //    The inits become operands of the new scf.for. An init produced by the
+  //    old loop would turn into a use of the new loop's own result once the
+  //    old loop's results are replaced in step 13.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::any_of(dpsInits, [&](Value init) {
+        return init.getDefiningOp() == forOp.getOperation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  // 3. Only unit-stride insertions are supported. The clone created in step 6
+  //    has identical operands, so check the original slice here, before any
+  //    IR is modified.
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(),
+                   [](OpFoldResult stride) {
+                     return !isConstantIntValue(stride, 1);
+                   })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // 4. Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+
+  // 5. Move the loop body to the new op. mergeBlocks deletes the source
+  //    block, so its argument count (induction variable + old iter_args) is
+  //    recorded first. That count is also the offset of the consumer's
+  //    iter_args among the new loop body's arguments.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  unsigned numOldBlockArgs = loopBody->getNumArguments();
+  rewriter.mergeBlocks(loopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(numOldBlockArgs));
+
+  // 6. Clone tensor.insert_slice after original tensor.insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 7.a. Clone consumer after the cloned tensor.insert_slice op, using the new
+  //      loop's extra iter_args as its destinations.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(numOldBlockArgs),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 7.b. Replace the cloned consumer's use of the loop result with the result
+  //      of the cloned tensor.insert_slice.
+  rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getResult(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 8. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op");
+  }
+  assert(!(tileAndFuseResult->tiledOps.empty()) && "tiled consumer not found");
+
+  // 9. Extract the offsets/sizes of the inserted slice (strides were verified
+  //    to be 1 in step 3).
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+
+  // 10. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to get all containing op result's position from iter domain
+  //     position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (unsigned idx = 0, e = clonedConsumerOp->getNumResults(); idx < e;
+       ++idx) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 12. Fix terminator. Yield the old values plus, for each result of the
+  //     tiled consumer, a tensor.insert_slice of that tile into the matching
+  //     consumer iter_arg of the new loop.
+  auto oldTerminatorOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  auto bbArgs = newLoopBody->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> insertStrides(resultPositions[idx].first.size(),
+                                            rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        clonedCandidateSliceOp->getLoc(), v, bbArgs[numOldBlockArgs + idx],
+        resultPositions[idx].first, resultPositions[idx].second,
+        insertStrides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // 13. Replace the result of scf.for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
+
+  // 14. Need to erase the old scf.for and the cloned consumer op.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fetch the untiled consumer of an scf.forall result that is yielded by a
+/// tensor.parallel_insert_slice. The slice's destination must be one of the
+/// region out args of an enclosing scf.forall; otherwise this returns failure
+/// instead of dereferencing a null value or indexing past the loop results.
+/// It also requires the corresponding scf.forall result to have exactly one
+/// use, by an op implementing both TilingInterface and
+/// DestinationStyleOpInterface, and returns that use.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output. The destination must be a block
+  // argument to possibly be one of the loop's region out args.
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg) {
+    return failure();
+  }
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast_or_null<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return failure();
+  }
+  // Locate `iterArg` among the region out args. If it is not there, fail
+  // rather than fall through with an out-of-range result number.
+  unsigned resultNumber = 0;
+  bool isRegionOutArg = false;
+  for (BlockArgument val : forallOp.getRegionOutArgs()) {
+    if (val == iterArg) {
+      isRegionOutArg = true;
+      break;
+    }
+    ++resultNumber;
+  }
+  if (!isRegionOutArg) {
+    return failure();
+  }
+  Value resultingValue = forallOp->getResult(resultNumber);
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses())) {
+    return failure();
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = *resultingValue.getUses().begin();
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  // DestinationStyleOpInterface to get result shape from init for now.
+  // Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return failure();
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the dest.
+ FailureOr<OpOperand *> consumerOpOperand =
+ getUntiledConsumerFromSlice(candidateSliceOp);
+ if (failed(consumerOpOperand)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = (*consumerOpOperand)->getOwner();
+ unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+ unsigned resultNumber =
+ cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ auto forallOp = static_cast<scf::ForallOp>(containingOp);
+
+ auto dstOp = static_cast<DestinationStyleOpInterface>(consumerOp);
+ // 2. Check consumer is not using scf.forall's output as init.
+ SmallVector<Value> dpsInits =
+ llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+ if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+ return rewriter.notifyMatchFailure(
+ consumerOp,
+ "consumer op taking the result of scf.forall as init is not supported");
+ }
+
+ Location loc = forallOp.getLoc();
+ // 3. Create new scf.forall op.
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+ // 4. Move the loop body to the new op.
+ rewriter.eraseOp(newforallOp.getTerminator());
+ Block *loopBody = forallOp.getBody();
+ Block *newLoopBody = newforallOp.getBody();
+ rewriter.mergeBlocks(
+ loopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+ // 5. Clone tensor.parallel_insert_slice after the original
+ // tensor.parallel_insert_slice.
+ rewriter.setInsertionPointAfter(candidateSliceOp);
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+ // 6.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+ rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+ SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+ newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
+ [](BlockArgument b) -> Value { return b; });
----------------
ftynse wrote:
Do we really need to materialize a vector here? The parameter of `cloneOpAndUpdateDestinationArgs` is implicitly constructible from an `ArrayRef<BlockArgument>`. If we do need the vector, we may be missing implicit-conversion support somewhere that we should add (in a separate patch).
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list