[Mlir-commits] [mlir] [DO NOT REVIEW][mlir][scf] Extend consumer fuse to nested loop structure (PR #94460)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 5 04:13:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (Yun-Fly)
<details>
<summary>Changes</summary>
The CI failure of #<!-- -->94190 could not be reproduced locally after several attempts, so this mirror PR was created purely for debugging purposes.
Please **DO NOT REVIEW** and **ignore** all changes for this.
Once the root cause is figured out, I will close this PR.
Sorry for the disturbance again. Thanks.
---
Patch is 47.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94460.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+517-221)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+96)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..488a1c6585790 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
- Value result = candidateSliceOp.getResult();
- Value::use_range uses = result.getUses();
- if (!llvm::hasSingleElement(uses)) {
- LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
- return failure();
- }
- OpOperand &operandUse = (*uses.begin());
- Operation *userOp = operandUse.getOwner();
- if (!isa<scf::YieldOp>(userOp)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Expected scf.yield to be the only user, but got -> "
- << (*userOp));
- return failure();
- }
- if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
- "be in the same block\n");
- return failure();
- }
- return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
- Block *containingOpBlock) {
- // Step 1. Check that the value has exactly one use.
- if (!llvm::hasSingleElement(val.getUses()))
- return failure();
- // Step 2. Get uses.
- OpOperand &operand = (*val.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- return failure();
- if (containingOpBlock != consumerOp->getBlock())
- return failure();
- return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1. tensor.insert_slice has scf.yield as its only user.
-/// 2. scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
- if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
- return failure();
- Value sliceResult = candidateSliceOp.getResult();
- // Step 1. Fetch the corresponding output.
- OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
- unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- Value resultingValue = forOp->getResult(resultNumber);
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
- Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = dyn_cast<BlockArgument>(sliceDest);
- if (!iterArg)
- return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
- return failure();
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
/// This utility currently checks whether the loop either :-
/// 1. Yields exactly one result.
/// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,117 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
- } else {
+/// Traverse and collect all outer loops of the given sliceOp, sorted from
+/// outermost to innermost. If `untilLoop` is found, stop the walk early.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+ Operation *sliceOp,
+ std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+ assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
+ SmallVector<LoopLikeOpInterface> outerLoops;
+ auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ outerLoops.push_back(forOp);
+ if (untilLoop.has_value() && *untilLoop == forOp)
+ break;
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
+/// ```
+/// %1 = scf.for
+/// %2 = scf.for
+/// %3 = scf.for
+/// ...
+/// %4 = insert
+/// yield %4
+/// %5 = insert %3
+/// yield %5
+/// yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @return first: resultValue: %1
+/// second: Collected insertSliceOp List during walk including
+/// targetSliceOp: %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
+ int curDepth = 0, int maxDepth = 5) {
+ assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ // Bound the recursion depth to avoid stack overflow
+ if (curDepth > maxDepth)
+ return failure();
+
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+ candidateSliceOpList.push_back(
+ cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue);
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
return failure();
+
+ while (true) {
+ bool walkThroughOuterLoop = false;
+ for (OpOperand &useOperand : resultOfLoop.getUses()) {
+ if (auto sliceOp =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+ FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+ resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
+ sliceOp, curDepth + 1);
+ if (failed(resultAndSliceOpsPair))
+ return failure();
+ candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+ (*resultAndSliceOpsPair).second.end());
+ return std::make_pair((*resultAndSliceOpsPair).first,
+ candidateSliceOpList);
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ // walk through outer loop
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ walkThroughOuterLoop = true;
+ break;
+ }
+ }
+ if (!walkThroughOuterLoop)
+ break;
}
+ return std::make_pair(resultOfLoop, candidateSliceOpList);
}
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+ unsigned totalTiledResults = tilingResult.size();
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(totalOldResults + totalTiledResults);
for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1248,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
rewriter.setInsertionPointAfter(oldTerminatorOp);
Location loc = newForOp.getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
- resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1269,16 +1263,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
/// values by the tiled consumer within scf.forall.in_parallel region.
static void
fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1280,176 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
}
}
+/// If the top level loop of nested loop structure is scf.forall, need to create
+/// additional tensor.extract_slice for its new appended `shared_outs` in order
+/// to pass correct local memory for inner loops. E.g.
+///
+/// scf.forall shared_outs(%o1=..., %o2=...) {
+/// %local_o1 = extract_slice %o1
+/// // fix new appended `shared_out` %o2
+/// %local_o2 = extract_slice %o2
+/// scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+/// ...
+/// }
+/// ...
+/// }
+static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
+ RewriterBase &rewriter, Operation *loop, ValueRange newSharedOuts,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+ rewriter.setInsertionPoint(loop);
+ Location loc = loop->getLoc();
+ // create new ExtractOps for newInits from scf.forall
+ SmallVector<tensor::ExtractSliceOp> newExtractOps;
+ newExtractOps.reserve(resultOffsets.size());
+ for (auto [bbArg, offset, sizes] :
+ llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, bbArg, offset, sizes, strides);
+ newExtractOps.push_back(newExtractOp);
+ }
+ return newExtractOps;
+}
+
+/// If outerMost loop of nested loop structure is `scf.forall`, need to deal
+/// with DpsInit of tiled consumer
+static void
+fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
+ ArrayRef<BlockArgument> bbArgs,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+ rewriter.setInsertionPoint(tiledConsumer);
+ Location loc = tiledConsumer->getLoc();
+ for (auto &&[bbArg, offset, sizes, dpsInit] :
+ llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+ cast<DestinationStyleOpInterface>(tiledConsumer)
+ .getDpsInitsMutable())) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, bbArg, offset, sizes, strides);
+ dpsInit.set(newExtractOp.getResult());
+ }
+}
+
+/// Compute all results tile by given SliceOp along operand
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+ RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+ OffsetSizeAndStrideOpInterface ossSliceOp,
+ SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+ // 1. Check that all strides are 1
+ if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+ }
+ // 2. Compute iteration domain tile from the input position
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+ ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp, "can't get iter domain position from input position");
+ }
+ unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ // 3. Compute result Tile by resultNumber
+ for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+ allResultOffsets = resultOffsets;
+ allResultSizes = resultSizes;
+ return success();
+}
+
+/// Since multi-level tensor.*SliceOps may be based on different coordinate
+/// systems, this utility computes the real OFFSET relative to the ROOT
+/// SliceOp. E.g.
+/// %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+/// %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+///
+/// where the coordination can be illustrated as follow:
+///
+/// %3 ----------------------------------
+/// | | |
+/// | OFFSET2 | OFFSET1 |
+/// | ------ %0 |
+/// | |
+/// | |
+/// |------------------ %1 ------ |
+/// | | SIZE1 |
+/// | | |
+/// | | |
+/// | | ------- |
+/// |
+///
+/// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+ RewriterBase &rewriter, Location loc,
+ OffsetSizeAndStrideOpInterface candidateSliceOp,
+ MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+ if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "candidateSliceOp has stride");
+ }
+ SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+ // Real offsets equals to accumulative offsets of outer candidates
+ for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+ iter++) {
+ // assert each outer candidate slice has no stride
+ if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return failure();
+ }
+ for (auto &&[ofr1, ofr2] :
+ llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+ using AVE = affine::AffineValueExpr;
+ affine::AffineBuilder ab(rewriter, loc);
+ AffineExpr dim0, dim1, sym;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ bindSymbols(rewriter.getContext(), sym);
+ auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+ ofr1 = ab.add(aveOffset1, aveOffset2);
+ }
+ }
+ return realOffsets;
+}
+
+/// Get the first tilable user of given Value and check its domination at the
+/// same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+ for (auto &useOfval : val.getUses()) {
+ Operation *consumerOp = useOfval.getOwner();
+ // 1. Check whether consumerOp is tilable
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ continue;
+ // 2. Check stay in same block with loopOp
+ if (loopOp->getBlock() != consumerOp->getBlock())
+ continue;
+ // 3. Check no other user before it
+ if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+ continue;
+ }
+ return &useOfval;
+ }
+ return failure();
+}
+
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1461,28 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
bool isInsertSliceOp = isa<tensor...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/94460
More information about the Mlir-commits
mailing list