[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jul 6 06:10:11 PDT 2024
https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/94190
>From acc7af60495c4aff1ca7c18a48021a66bbf20bdf Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Sun, 2 Jun 2024 23:32:22 -0700
Subject: [PATCH 1/3] extend consumer fuse to nested scf loop
---
.../SCF/Transforms/TileUsingInterface.cpp | 731 ++++++++++++------
.../tile-and-fuse-consumer.mlir | 96 +++
2 files changed, 606 insertions(+), 221 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..8a7eac1d883f9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1197,98 +1198,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
- Value result = candidateSliceOp.getResult();
- Value::use_range uses = result.getUses();
- if (!llvm::hasSingleElement(uses)) {
- LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
- return failure();
- }
- OpOperand &operandUse = (*uses.begin());
- Operation *userOp = operandUse.getOwner();
- if (!isa<scf::YieldOp>(userOp)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Expected scf.yield to be the only user, but got -> "
- << (*userOp));
- return failure();
- }
- if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
- "be in the same block\n");
- return failure();
- }
- return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
- Block *containingOpBlock) {
- // Step 1. Check that the value has exactly one use.
- if (!llvm::hasSingleElement(val.getUses()))
- return failure();
- // Step 2. Get uses.
- OpOperand &operand = (*val.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- return failure();
- if (containingOpBlock != consumerOp->getBlock())
- return failure();
- return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1. tensor.insert_slice has scf.yield as its only user.
-/// 2. scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
- if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
- return failure();
- Value sliceResult = candidateSliceOp.getResult();
- // Step 1. Fetch the corresponding output.
- OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
- unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- Value resultingValue = forOp->getResult(resultNumber);
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
- Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = dyn_cast<BlockArgument>(sliceDest);
- if (!iterArg)
- return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
- return failure();
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
/// This utility currently checks whether the loop either :-
/// 1. Yields exactly one result.
/// 2. Has consumer op as its first user and other users to be in the same
@@ -1314,31 +1223,117 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
- } else {
+/// Traverse and collect all outer loops of the given sliceOp, sorted from
+/// outermost to innermost. If `untilLoop` is found, stop the walk early.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+ Operation *sliceOp,
+ std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+ assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
+ SmallVector<LoopLikeOpInterface> outerLoops;
+ auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ outerLoops.push_back(forOp);
+ if (untilLoop.has_value() && *untilLoop == forOp)
+ break;
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Get the result of the top-level loop yielding the target InsertSliceOp. E.g.
+/// ```
+/// %1 = scf.for
+/// %2 = scf.for
+/// %3 = scf.for
+/// ...
+/// %4 = insert
+/// yield %4
+/// %5 = insert %3
+/// yield %5
+/// yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @return first: resultValue: %1
+/// second: Collected insertSliceOp List during walk including
+/// targetSliceOp: %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
+ int curDepth = 0, int maxDepth = 5) {
+ assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ // Bound the recursion depth to avoid a stack overflow.
+ if (curDepth > maxDepth)
+ return failure();
+
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+ candidateSliceOpList.push_back(
+ cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue);
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
return failure();
+
+ while (true) {
+ bool walkThroughOuterLoop = false;
+ for (OpOperand &useOperand : resultOfLoop.getUses()) {
+ if (auto sliceOp =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+ FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+ resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
+ sliceOp, curDepth + 1);
+ if (failed(resultAndSliceOpsPair))
+ return failure();
+ candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+ (*resultAndSliceOpsPair).second.end());
+ return std::make_pair((*resultAndSliceOpsPair).first,
+ candidateSliceOpList);
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ // walk through outer loop
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ walkThroughOuterLoop = true;
+ break;
+ }
+ }
+ if (!walkThroughOuterLoop)
+ break;
}
+ return std::make_pair(resultOfLoop, candidateSliceOpList);
}
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+ unsigned totalTiledResults = tilingResult.size();
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(totalOldResults + totalTiledResults);
for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1347,8 +1342,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
rewriter.setInsertionPointAfter(oldTerminatorOp);
Location loc = newForOp.getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
- resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1363,16 +1357,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
/// values by the tiled consumer within scf.forall.in_parallel region.
static void
fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1380,6 +1374,176 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
}
}
+/// If the top-level loop of a nested loop structure is scf.forall, additional
+/// tensor.extract_slice ops are needed for its newly appended `shared_outs` so
+/// that the correct local memory is passed to the inner loops. E.g.
+///
+/// scf.forall shared_outs(%o1=..., %o2=...) {
+/// %local_o1 = extract_slice %o1
+/// // fix new appended `shared_out` %o2
+/// %local_o2 = extract_slice %o2
+/// scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+/// ...
+/// }
+/// ...
+/// }
+static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
+ RewriterBase &rewriter, Operation *loop, ArrayRef<Value> newSharedOuts,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+ rewriter.setInsertionPoint(loop);
+ Location loc = loop->getLoc();
+ // create new ExtractOps for newInits from scf.forall
+ SmallVector<tensor::ExtractSliceOp> newExtractOps;
+ newExtractOps.reserve(resultOffsets.size());
+ for (auto [bbArg, offset, sizes] :
+ llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, bbArg, offset, sizes, strides);
+ newExtractOps.push_back(newExtractOp);
+ }
+ return newExtractOps;
+}
+
+/// If the outermost loop of the nested loop structure is `scf.forall`, the
+/// DPS inits of the tiled consumer must be replaced by slices of `bbArgs`.
+static void
+fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
+ ArrayRef<BlockArgument> bbArgs,
+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+ rewriter.setInsertionPoint(tiledConsumer);
+ Location loc = tiledConsumer->getLoc();
+ for (auto &&[bbArg, offset, sizes, dpsInit] :
+ llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+ cast<DestinationStyleOpInterface>(tiledConsumer)
+ .getDpsInitsMutable())) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, bbArg, offset, sizes, strides);
+ dpsInit.set(newExtractOp.getResult());
+ }
+}
+
+/// Compute the tile of every result from the given slice op of an operand.
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+ RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+ OffsetSizeAndStrideOpInterface ossSliceOp,
+ SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+ // 1. Check that all strides are 1.
+ if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+ }
+ // 2. Compute iteration domain tile from the input position
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+ ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp, "can't get iter domain position from input position");
+ }
+ unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ // 3. Compute result Tile by resultNumber
+ for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+ allResultOffsets = resultOffsets;
+ allResultSizes = resultSizes;
+ return success();
+}
+
+/// Since multi-level tensor.*SliceOps may be based on different coordinate
+/// systems, this utility computes the real OFFSET relative to the ROOT
+/// SliceOp. E.g.
+/// %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+/// %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+///
+/// where the coordinates can be illustrated as follows:
+///
+/// %3 ----------------------------------
+/// | | |
+/// | OFFSET2 | OFFSET1 |
+/// | ------ %0 |
+/// | |
+/// | |
+/// |------------------ %1 ------ |
+/// | | SIZE1 |
+/// | | |
+/// | | |
+/// | | ------- |
+/// |
+///
+/// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+ RewriterBase &rewriter, Location loc,
+ OffsetSizeAndStrideOpInterface candidateSliceOp,
+ MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+ if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "candidateSliceOp has stride");
+ }
+ SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+ // The real offsets are the accumulated offsets of the outer candidates.
+ for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+ iter++) {
+ // Bail out if any outer candidate slice has a non-unit stride.
+ if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return failure();
+ }
+ for (auto &&[ofr1, ofr2] :
+ llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+ using AVE = affine::AffineValueExpr;
+ affine::AffineBuilder ab(rewriter, loc);
+ AffineExpr dim0, dim1, sym;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ bindSymbols(rewriter.getContext(), sym);
+ auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+ ofr1 = ab.add(aveOffset1, aveOffset2);
+ }
+ }
+ return realOffsets;
+}
+
+/// Get the first tilable user of the given Value, checking its dominance at
+/// the same time.
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+ for (auto &useOfval : val.getUses()) {
+ Operation *consumerOp = useOfval.getOwner();
+ // 1. Check whether consumerOp is tilable
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ continue;
+ // 2. Check that it is in the same block as loopOp
+ if (loopOp->getBlock() != consumerOp->getBlock())
+ continue;
+ // 3. Check no other user before it
+ if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+ continue;
+ }
+ return &useOfval;
+ }
+ return failure();
+}
+
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1391,10 +1555,28 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
- // 1. Get the consumer of scf.for for the result yielded by
- // tensor.insert_slice/parallel_insert_slice.
+ // 1.a. Get the real consumer of candidate
+ // tensor.insert_slice/parallel_insert_slice by walking through
+ // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
+ // way.
+ FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+ resultAndSliceOpsPair =
+ getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp);
+ if (failed(resultAndSliceOpsPair)) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "could not fetch consumer to fuse");
+ }
+
+ // 1.b. Get all outer loops of candidateSliceOp.
+ SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
+ candidateSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
+ .getDefiningOp<LoopLikeOpInterface>());
+ LoopLikeOpInterface outerMostLoop = outerLoops.front();
+
+ // 2. Get first tilable consumer op
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(candidateSliceOp);
+ getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
+ outerMostLoop);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1410,111 +1592,193 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- Operation *oldLoopOp = nullptr;
- SmallVector<Value> newOuts;
- Block *oldLoopBody = nullptr;
- unsigned initSize = 0;
- unsigned rank = 1;
- if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
- initSize = forOp.getInits().size();
- } else {
- auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
- oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
- initSize = forallOp.getOutputs().size();
- rank = forallOp.getRank();
- }
-
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
- return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
- }
-
- OpBuilder::InsertionGuard g(rewriter);
-
- // 2. Check consumer is not using scf loop's output as init.
+ // 3. Check consumer is not using outerMostLoop's output as init.
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
- Location loc = oldLoopOp->getLoc();
+ // 4.a. Reconstruct nested loop from outer to inner.
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
+ (*resultAndSliceOpsPair).second;
+ SmallVector<LoopLikeOpInterface> newOuterLoops;
+ SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
+ SmallVector<tensor::ExtractSliceOp> newExtractOps;
- // 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
- Operation *newLoopOp = nullptr;
+ Block *oldLoopBody = nullptr;
Block *newLoopBody = nullptr;
+ SmallVector<Value> newInitAppend = dpsInits, newOuts;
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // 4.b. Set insertPoint right before consumerOp
+ rewriter.setInsertionPoint(consumerOp);
+
+ for (auto [index, loop] :
+ llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
+ if (index > 0) {
+ rewriter.setInsertionPoint(loop);
+ // 4.c. Create `extractSliceOp` for newInits if they come from the sharedOut
+ // of the previous `scf.forall` loop.
+ if (auto prevOuterLoop =
+ dyn_cast<scf::ForallOp>(newOuterLoops.back().getOperation())) {
+ if (index != 1) {
+ return rewriter.notifyMatchFailure(
+ prevOuterLoop, "Currently only outerMostLoop assumed forallOp");
+ }
+ OffsetSizeAndStrideOpInterface outerMostCandidate =
+ candidateSliceOpList.back();
+ if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+ rewriter, cast<TilingInterface>(consumerOp), operandNumber,
+ outerMostCandidate, allResultOffsets, allResultSizes))) {
+ return failure();
+ }
+ newExtractOps = fixLoopInitFromSharedOutSCFForall(
+ rewriter, loop, newInitAppend, allResultOffsets, allResultSizes);
+ newInitAppend = llvm::map_to_vector(
+ newExtractOps,
+ [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
+ }
+ }
+ LoopLikeOpInterface newLoopOp;
+ // 4.d. Create a new loop with the new init values for this loop.
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+ newOuts = llvm::to_vector(forOp.getInits());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ auto newLoop = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newOuts);
+ newLoopOp = newLoop;
+ oldLoopBody = forOp.getBody();
+ newLoopBody = newLoop.getBody();
+ } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
+ newOuts = llvm::to_vector(forallOp.getOutputs());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ auto newLoop = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
+ rewriter.eraseOp(newLoop.getTerminator());
+ newLoopOp = newLoop;
+ oldLoopBody = forallOp.getBody();
+ newLoopBody = newLoop.getBody();
+ }
+ newInitAppend = llvm::map_to_vector(
+ newLoopBody->getArguments().take_back(newInitAppend.size()),
+ [](BlockArgument bArg) -> Value { return bArg; });
+ rewriter.mergeBlocks(
+ oldLoopBody, newLoopBody,
+ newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
+ rewriter.replaceOp(
+ loop, newLoopOp->getResults().take_front(loop->getNumResults()));
+ newOuterLoops.push_back(newLoopOp);
+ }
+
+ // 5.a. Reconstruct inner-most loop.
+ LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
+ Location loc = oldInnerMostLoop->getLoc();
+ if (outerLoops.size() > 1)
+ rewriter.setInsertionPoint(oldInnerMostLoop);
+
if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldLoopOp);
+ auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
+ newOuts = llvm::to_vector(forOp.getInits());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ oldLoopBody = forOp.getBody();
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newOuts);
- newLoopOp = newForOp;
+ newInnerMostLoop = newForOp;
newLoopBody = newForOp.getBody();
} else {
- auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
+ newOuts = llvm::to_vector(forallOp.getOutputs());
+ newOuts.append(newInitAppend.begin(), newInitAppend.end());
+ oldLoopBody = forallOp.getBody();
auto newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newLoopOp = newForallOp;
+ newInnerMostLoop = newForallOp;
rewriter.eraseOp(newForallOp.getTerminator());
newLoopBody = newForallOp.getBody();
}
- // 4. Move the loop body to the new op.
+ // 5.b. Move the loop body to the new op.
unsigned oldNumArguments = oldLoopBody->getNumArguments();
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
newLoopBody->getArguments().take_front(oldNumArguments));
+ // 5.c. Replace the results of oldInnerMostLoop with newInnerMostLoop's
+ // results.
+ rewriter.replaceOp(oldInnerMostLoop,
+ newInnerMostLoop->getResults().take_front(
+ oldInnerMostLoop->getNumResults()));
- // 5. Set insertion point before terminator op of the loop and create a new
+ // 6. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
rewriter.setInsertionPoint(newForallOp.getTerminator());
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
rewriter.setInsertionPoint(candidateSliceOp);
- clonedInsertSliceOp =
- cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
}
- // 6.a. Clone consumer op.
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(oldNumArguments);
+ // 7.a. Due to the current assumption of `getTiledImplementation` that all
+ // `Operands` are untiled with the original tensor size, create a dummy
+ // `insertSliceOp` to align with that requirement.
+ auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
+ FailureOr<SmallVector<OpFoldResult>> realOffsets =
+ computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
+ candidateSliceOpList);
+ if (failed(realOffsets))
+ return failure();
+ clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, candidateSliceOp->getOperand(0),
+ candidateSliceOpList.back()->getOperand(1), *realOffsets,
+ ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
+
+ SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
+ newLoopBody->getArguments().drop_front(oldNumArguments),
+ [](BlockArgument bArg) -> Value { return bArg; });
+
+ // 7.b. If the outerMostLoop is scf.forall, then `newExtractOps` have been
+ // additionally created at `step 4.c` for `dpsInits`. As the counterpart, an
+ // `insertSliceOp` is also needed, for the same purpose as in `step 7.a`.
+ if (!newExtractOps.empty()) {
+ for (auto &&[extractOp, newDpsInit] :
+ llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
+ auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
+ extractOp.getMixedSizes(), extractOp.getMixedStrides());
+ newDpsInit = alignDpsInsertSliceOp.getResult();
+ }
+ }
+
+ // 8.a. Clone consumer op.
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ rewriter, consumerOp, newDpsInitsForConsumerDest));
- // 6.b. Replace all uses of the loop result with the result of the cloned
+ // 8.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
});
- // 7 - Perform tiling of the cloned consumer and replace the operand at
+ // 9. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
- auto ossSliceOp =
- cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
- rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+ rewriter,
+ cast<OffsetSizeAndStrideOpInterface>(
+ clonedInsertSliceOp.getOperation()),
+ clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
return failure();
}
@@ -1522,75 +1786,100 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
clonedInsertSliceOp.getSource());
- // 8 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
- // 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
-
- // 10. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
-
- // 11. Try to fetch the offset and size for all results of the cloned
+ // 10. Try to fetch the offset and size for all results of the cloned
// consumer. This would then be used to form the corresponding
// tensor.insert_slice/parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
- }
+ if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+ rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
+ allResultOffsets, allResultSizes))) {
+ return failure();
+ }
+
+ if (!newExtractOps.empty()) {
+ fixDpsInitsOfTiledConsumer(
+ rewriter, tileAndFuseResult->tiledOps[0],
+ newLoopBody->getArguments().drop_front(oldNumArguments),
+ allResultOffsets, allResultSizes);
}
- auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
- auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
if (isInsertSliceOp) {
- auto newForOp = cast<scf::ForOp>(newLoopOp);
+ auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
fixTerminatorSCFYield(
- rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
- newForOp.getBody()->getArguments().drop_front(1 + initSize));
+ rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
+ allResultOffsets, allResultSizes,
+ newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
} else {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
fixTerminatorSCFInParallel(
rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
- arrayRefOffsets, arrayRefSizes,
- newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+ allResultOffsets, allResultSizes,
+ newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
}
- // 12. Replace the result of scf loop and consumer op with new loop's results.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
+ newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
+
+ // 11.a. Reconstruct terminator of outer loop by inner loop.
+ auto outerCandidateIter = candidateSliceOpList.rbegin();
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
+ MutableArrayRef(newOuterLoops).drop_front())) {
+ // 11.b. Create insertSliceOp according outer candidateSliceOp
+ if (outerCandidateIter != candidateSliceOpList.rend() &&
+ outerCandidateIter->getOperation()
+ ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
+ if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+ } else {
+ rewriter.setInsertionPointAfter(*outerCandidateIter);
+ }
+
+ if (failed(computeAllResultTileForOpGivenOperandSliceOp(
+ rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
+ allResultOffsets, allResultSizes))) {
+ return failure();
+ }
+
+ if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
+ fixTerminatorSCFYield(
+ rewriter, forOp,
+ innerLoop->getResults().take_back(newInitAppend.size()),
+ allResultOffsets, allResultSizes,
+ forOp.getBody()->getArguments().take_back(newInitAppend.size()));
+ } else if (auto forallOp =
+ dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
+ fixTerminatorSCFInParallel(
+ rewriter, forallOp,
+ innerLoop->getResults().take_back(newInitAppend.size()),
+ allResultOffsets, allResultSizes,
+ forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+ }
+ outerCandidateIter++;
+ } else {
+ // 11.c. Yield additional new appended results of innerLoop
+ assert(isa<scf::ForOp>(outerLoop));
+ auto forOp = cast<scf::ForOp>(outerLoop);
+ auto outerLoopYield =
+ cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ SmallVector<Value> newYields =
+ llvm::to_vector(outerLoopYield.getOperands());
+ ValueRange additionalYields =
+ innerLoop->getResults().take_back(newInitAppend.size());
+ newYields.append(additionalYields.begin(), additionalYields.end());
+ rewriter.setInsertionPoint(outerLoopYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+ }
}
+ // 12. Replace the result of consumer op with new outerMost loop's
+ // results.
for (auto &&[oldResult, newResult] :
llvm::zip(consumerOp->getResults(),
- newLoopOp->getResults().drop_front(initSize))) {
+ newOuterLoops.front()->getResults().take_back(
+ newInitAppend.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the old scf loop and the cloned consumer op.
- rewriter.eraseOp(oldLoopOp);
+ // 13. Need to erase the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp);
return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..d60d5f4fd7b3c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,99 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+ func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+ %iv0 = affine.apply #map(%arg3)
+ %iv1 = affine.apply #map(%arg4)
+ %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+ %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+ %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+ %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+ %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+ %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+ %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+ scf.yield %insert_slice : tensor<128x128xf32>
+ }
+ scf.yield %3 : tensor<128x128xf32>
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+ }
+ }
+ %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %5 : tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 128)>
+// CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[dest1:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[dest0]] :
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+// CHECK: %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+// CHECK: %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME: {
+// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE1]] :
+// CHECK: %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
+// CHECK: %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
+// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
+// CHECK: %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE1]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+// CHECK: }
+// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+// CHECK: }
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
>From a841446cb1258a465c9bff37beb31084b0d2eb77 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Mon, 24 Jun 2024 18:22:21 -0700
Subject: [PATCH 2/3] Revert "extend consumer fuse to nested scf loop (v1)"
This reverts commit 5ba402530cd9909ae76136d8d253e4897987dfb7.
---
.../SCF/Transforms/TileUsingInterface.cpp | 731 ++++++------------
.../tile-and-fuse-consumer.mlir | 96 ---
2 files changed, 221 insertions(+), 606 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8a7eac1d883f9..a1392813d6de3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1198,6 +1197,98 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
+/// A utility function that checks whether the only use of the result of a
+/// tensor.insert_slice op is in a scf.yield op.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+ Value result = candidateSliceOp.getResult();
+ Value::use_range uses = result.getUses();
+ if (!llvm::hasSingleElement(uses)) {
+ LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+ return failure();
+ }
+ OpOperand &operandUse = (*uses.begin());
+ Operation *userOp = operandUse.getOwner();
+ if (!isa<scf::YieldOp>(userOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Expected scf.yield to be the only user, but got -> "
+ << (*userOp));
+ return failure();
+ }
+ if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
+ LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+ "be in the same block\n");
+ return failure();
+ }
+ return success();
+}
+
+/// Fetches the OpOperand of the only user (and use) of the value `val` which
+/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
+/// failure otherwise.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+ Block *containingOpBlock) {
+ // Step 1. Check that the value has exactly one use.
+ if (!llvm::hasSingleElement(val.getUses()))
+ return failure();
+ // Step 2. Get uses.
+ OpOperand &operand = (*val.getUses().begin());
+ Operation *consumerOp = operand.getOwner();
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ return failure();
+ if (containingOpBlock != consumerOp->getBlock())
+ return failure();
+ return &operand;
+}
+
+/// Fetch the untiled consumer of a scf.for's result which is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions :
+/// 1. tensor.insert_slice has scf.yield as its only user.
+/// 2. scf.for's corresponding result has only one use.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+ if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+ return failure();
+ Value sliceResult = candidateSliceOp.getResult();
+ // Step 1. Fetch the corresponding output.
+ OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+ unsigned resultNumber = yieldOpOperand.getOperandNumber();
+ // Step 2. Check containing op is scf.for.
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ auto forOp = dyn_cast<scf::ForOp>(containingOp);
+ if (!forOp)
+ return failure();
+ Value resultingValue = forOp->getResult(resultNumber);
+
+ return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // Step 1. Fetch the corresponding output
+ Value sliceDest = candidateSliceOp.getDest();
+ auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+ if (!iterArg)
+ return failure();
+ Operation *containingOp = iterArg.getOwner()->getParentOp();
+ if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
+ return failure();
+ // Step 2. Check that the containing op is scf.forall.
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp)
+ return failure();
+ Value resultingValue =
+ forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+
+ return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
/// This utility currently checks whether the loop either :-
/// 1. Yields exactly one result.
/// 2. Has consumer op as its first user and other users to be in the same
@@ -1223,117 +1314,31 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// Traverse and collect all outer loops of given sliceOp, sorted by
-/// outer-to-inner. If `untilLoop` found, stop walk through in advance.
-static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
- Operation *sliceOp,
- std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
- assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
- SmallVector<LoopLikeOpInterface> outerLoops;
- auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
- while (forOp) {
- outerLoops.push_back(forOp);
- if (untilLoop.has_value() && *untilLoop == forOp)
- break;
- forOp = forOp->getParentOfType<LoopLikeOpInterface>();
- }
- return {outerLoops.rbegin(), outerLoops.rend()};
-}
-
-/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
-/// ```
-/// %1 = scf.for
-/// %2 = scf.for
-/// %3 = scf.for
-/// ...
-/// %4 = insert
-/// yield %4
-/// %5 = insert %3
-/// yield %5
-/// yield %2
-/// ```
-/// @param targetSliceOp: %4 = insert
-/// @return first: resultValue: %1
-/// second: Collected insertSliceOp List during walk including
-/// targetSliceOp: %4 = insert and %5 = insert %3
-static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
-getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
- int curDepth = 0, int maxDepth = 5) {
- assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
- // Control recursive time in avoid of stack overflow
- if (curDepth > maxDepth)
- return failure();
-
- SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
- candidateSliceOpList.push_back(
- cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
- Value resultOfLoop;
- if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
- Value destValue = sliceOp.getDest();
- auto iterArg = cast<BlockArgument>(destValue);
- auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
- if (!forallOp)
- return failure();
- resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
- } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
- Value resultValue = sliceOp.getResult();
- for (auto &useOperand : resultValue.getUses()) {
- if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
- if (llvm::detail::isPresent(resultOfLoop))
- return failure();
- auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
- if (!forOp)
- return failure();
- resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
- }
- }
- }
-
- if (!llvm::detail::isPresent(resultOfLoop))
+/// A utility to fetch an untiled consumer of
+/// tensor.insert_slice/tensor.parallel_insert_slice.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+ if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
+ return getUntiledConsumerFromSlice(insertSlice);
+ } else if (auto parallelInsertSlice =
+ dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
+ return getUntiledConsumerFromSlice(parallelInsertSlice);
+ } else {
return failure();
-
- while (true) {
- bool walkThroughOuterLoop = false;
- for (OpOperand &useOperand : resultOfLoop.getUses()) {
- if (auto sliceOp =
- dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
- FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
- resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
- sliceOp, curDepth + 1);
- if (failed(resultAndSliceOpsPair))
- return failure();
- candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
- (*resultAndSliceOpsPair).second.end());
- return std::make_pair((*resultAndSliceOpsPair).first,
- candidateSliceOpList);
- } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
- // walk through outer loop
- auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
- if (!forOp)
- return failure();
- resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
- walkThroughOuterLoop = true;
- break;
- }
- }
- if (!walkThroughOuterLoop)
- break;
}
- return std::make_pair(resultOfLoop, candidateSliceOpList);
}
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- ResultRange tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+ TilingResult &tilingResult,
+ ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.size();
+ unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(totalOldResults + totalTiledResults);
for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1342,7 +1347,8 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
rewriter.setInsertionPointAfter(oldTerminatorOp);
Location loc = newForOp.getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+ resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1357,16 +1363,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
/// values by the tiled consumer within scf.forall.in_parallel region.
static void
fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- ResultRange tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+ SmallVector<Value> tiledResults,
+ ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
+ llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1374,176 +1380,6 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
}
}
-/// If the top level loop of nested loop structure is scf.forall, need to create
-/// additional tensor.extract_slice for its new appended `shared_outs` in order
-/// to pass correct local memory for inner loops. E.g.
-///
-/// scf.forall shared_outs(%o1=..., %o2=...) {
-/// %local_o1 = extract_slice %o1
-/// // fix new appended `shared_out` %o2
-/// %local_o2 = extract_slice %o2
-/// scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
-/// ...
-/// }
-/// ...
-/// }
-static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
- RewriterBase &rewriter, Operation *loop, ArrayRef<Value> newSharedOuts,
- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
- rewriter.setInsertionPoint(loop);
- Location loc = loop->getLoc();
- // create new ExtractOps for newInits from scf.forall
- SmallVector<tensor::ExtractSliceOp> newExtractOps;
- newExtractOps.reserve(resultOffsets.size());
- for (auto [bbArg, offset, sizes] :
- llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
- SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
- auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, bbArg, offset, sizes, strides);
- newExtractOps.push_back(newExtractOp);
- }
- return newExtractOps;
-}
-
-/// If outerMost loop of nested loop structure is `scf.forall`, need to deal
-/// with DpsInit of tiled consumer
-static void
-fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
- ArrayRef<BlockArgument> bbArgs,
- ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
- rewriter.setInsertionPoint(tiledConsumer);
- Location loc = tiledConsumer->getLoc();
- for (auto &&[bbArg, offset, sizes, dpsInit] :
- llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
- cast<DestinationStyleOpInterface>(tiledConsumer)
- .getDpsInitsMutable())) {
- SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
- auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, bbArg, offset, sizes, strides);
- dpsInit.set(newExtractOp.getResult());
- }
-}
-
-/// Compute all results tile by given SliceOp along operand
-static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
- RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
- OffsetSizeAndStrideOpInterface ossSliceOp,
- SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
- // 1. Check all stride all 1
- if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
- }
- // 2. Compute iteration domain tile from the input position
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(tilableOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
- ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- tilableOp, "can't get iter domain position from input position");
- }
- unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- // 3. Compute result Tile by resultNumber
- for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
- if (failed(tilableOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
- return rewriter.notifyMatchFailure(
- tilableOp,
- "can't get result domain position from iter domain position");
- }
- }
- allResultOffsets = resultOffsets;
- allResultSizes = resultSizes;
- return success();
-}
-
-/// Considering multi-level tensor.*SliceOp maybe based on different
-/// coordination, this utility computes the real OFFSET coordinated on ROOT
-/// SliceOp. E.g
-/// %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
-/// %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
-///
-/// where the coordination can be illustrated as follow:
-///
-/// %3 ----------------------------------
-/// | | |
-/// | OFFSET2 | OFFSET1 |
-/// | ------ %0 |
-/// | |
-/// | |
-/// |------------------ %1 ------ |
-/// | | SIZE1 |
-/// | | |
-/// | | |
-/// | | ------- |
-/// |
-///
-/// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
-static FailureOr<SmallVector<OpFoldResult>>
-computeRealOffsetsCoordinatedRootSliceOp(
- RewriterBase &rewriter, Location loc,
- OffsetSizeAndStrideOpInterface candidateSliceOp,
- MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
- if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "candidateSliceOp has stride");
- }
- SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
- // Real offsets equals to accumulative offsets of outer candidates
- for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
- iter++) {
- // assert each outer candidate slice has no stride
- if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return failure();
- }
- for (auto &&[ofr1, ofr2] :
- llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
- using AVE = affine::AffineValueExpr;
- affine::AffineBuilder ab(rewriter, loc);
- AffineExpr dim0, dim1, sym;
- bindDims(rewriter.getContext(), dim0, dim1);
- bindSymbols(rewriter.getContext(), sym);
- auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
- ofr1 = ab.add(aveOffset1, aveOffset2);
- }
- }
- return realOffsets;
-}
-
-/// Get the first tilable user of given Value and check its domination at the
-/// same time
-static FailureOr<OpOperand *>
-getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
- for (auto &useOfval : val.getUses()) {
- Operation *consumerOp = useOfval.getOwner();
- // 1. Check whether consumerOp is tilable
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- continue;
- // 2. Check stay in same block with loopOp
- if (loopOp->getBlock() != consumerOp->getBlock())
- continue;
- // 3. Check no other user before it
- if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
- continue;
- }
- return &useOfval;
- }
- return failure();
-}
-
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1555,28 +1391,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
- // 1.a. Get the real consumer of candidate
- // tensor.insert_slice/parallel_insert_slice by walking through
- // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
- // way.
- FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
- resultAndSliceOpsPair =
- getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp);
- if (failed(resultAndSliceOpsPair)) {
- return rewriter.notifyMatchFailure(candidateSliceOp,
- "could not fetch consumer to fuse");
- }
-
- // 1.b. Get all outer loops of candidateSliceOp.
- SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
- candidateSliceOp, dyn_cast<OpResult>((*resultAndSliceOpsPair).first)
- .getDefiningOp<LoopLikeOpInterface>());
- LoopLikeOpInterface outerMostLoop = outerLoops.front();
-
- // 2. Get first tilable consumer op
+ // 1. Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getTilableConsumerOperandFirstUseVal((*resultAndSliceOpsPair).first,
- outerMostLoop);
+ getUntiledConsumerFromSlice(candidateSliceOp);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1592,193 +1410,111 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- // 3. Check consumer is not using outerMostLoop's output as init.
+ Operation *oldLoopOp = nullptr;
+ SmallVector<Value> newOuts;
+ Block *oldLoopBody = nullptr;
+ unsigned initSize = 0;
+ unsigned rank = 1;
+ if (isInsertSliceOp) {
+ auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+ oldLoopOp = forOp;
+ llvm::append_range(newOuts, forOp.getInits());
+ oldLoopBody = forOp.getBody();
+ initSize = forOp.getInits().size();
+ } else {
+ auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+ oldLoopOp = forallOp;
+ llvm::append_range(newOuts, forallOp.getOutputs());
+ oldLoopBody = forallOp.getBody();
+ initSize = forallOp.getOutputs().size();
+ rank = forallOp.getRank();
+ }
+
+ if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ return rewriter.notifyMatchFailure(
+ oldLoopOp, "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 2. Check consumer is not using scf loop's output as init.
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
+ newOuts.append(dpsInits);
- // 4.a. Reconstruct nested loop from outer to inner.
- SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
- (*resultAndSliceOpsPair).second;
- SmallVector<LoopLikeOpInterface> newOuterLoops;
- SmallVector<SmallVector<OpFoldResult>> allResultOffsets, allResultSizes;
- SmallVector<tensor::ExtractSliceOp> newExtractOps;
-
- Block *oldLoopBody = nullptr;
- Block *newLoopBody = nullptr;
- SmallVector<Value> newInitAppend = dpsInits, newOuts;
+ Location loc = oldLoopOp->getLoc();
- OpBuilder::InsertionGuard g(rewriter);
- // 4.b. Set insertPoint right before consumerOp
+ // 3. Create new scf loop op.
rewriter.setInsertionPoint(consumerOp);
-
- for (auto [index, loop] :
- llvm::enumerate(MutableArrayRef(outerLoops).drop_back())) {
- if (index > 0) {
- rewriter.setInsertionPoint(loop);
- // 4.c. Create `extractSliceOp` for newInits if they comes from sharedOut
- // of previous `scf.forall` loop.
- if (auto prevOuterLoop =
- dyn_cast<scf::ForallOp>(newOuterLoops.back().getOperation())) {
- if (index != 1) {
- return rewriter.notifyMatchFailure(
- prevOuterLoop, "Currently only outerMostLoop assumed forallOp");
- }
- OffsetSizeAndStrideOpInterface outerMostCandidate =
- candidateSliceOpList.back();
- if (failed(computeAllResultTileForOpGivenOperandSliceOp(
- rewriter, cast<TilingInterface>(consumerOp), operandNumber,
- outerMostCandidate, allResultOffsets, allResultSizes))) {
- return failure();
- }
- newExtractOps = fixLoopInitFromSharedOutSCFForall(
- rewriter, loop, newInitAppend, allResultOffsets, allResultSizes);
- newInitAppend = llvm::map_to_vector(
- newExtractOps,
- [](tensor::ExtractSliceOp op) -> Value { return op.getResult(); });
- }
- }
- LoopLikeOpInterface newLoopOp;
- // 4.d. Create a new loop with the new init values for this loop.
- if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
- newOuts = llvm::to_vector(forOp.getInits());
- newOuts.append(newInitAppend.begin(), newInitAppend.end());
- auto newLoop = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- newLoopOp = newLoop;
- oldLoopBody = forOp.getBody();
- newLoopBody = newLoop.getBody();
- } else if (auto forallOp = dyn_cast<scf::ForallOp>(loop.getOperation())) {
- newOuts = llvm::to_vector(forallOp.getOutputs());
- newOuts.append(newInitAppend.begin(), newInitAppend.end());
- auto newLoop = rewriter.create<scf::ForallOp>(
- forallOp.getLoc(), forallOp.getMixedLowerBound(),
- forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
- forallOp.getMapping());
- rewriter.eraseOp(newLoop.getTerminator());
- newLoopOp = newLoop;
- oldLoopBody = forallOp.getBody();
- newLoopBody = newLoop.getBody();
- }
- newInitAppend = llvm::map_to_vector(
- newLoopBody->getArguments().take_back(newInitAppend.size()),
- [](BlockArgument bArg) -> Value { return bArg; });
- rewriter.mergeBlocks(
- oldLoopBody, newLoopBody,
- newLoopBody->getArguments().take_front(oldLoopBody->getNumArguments()));
- rewriter.replaceOp(
- loop, newLoopOp->getResults().take_front(loop->getNumResults()));
- newOuterLoops.push_back(newLoopOp);
- }
-
- // 5.a. Reconstruct inner-most loop.
- LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
- Location loc = oldInnerMostLoop->getLoc();
- if (outerLoops.size() > 1)
- rewriter.setInsertionPoint(oldInnerMostLoop);
-
+ Operation *newLoopOp = nullptr;
+ Block *newLoopBody = nullptr;
if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldInnerMostLoop.getOperation());
- newOuts = llvm::to_vector(forOp.getInits());
- newOuts.append(newInitAppend.begin(), newInitAppend.end());
- oldLoopBody = forOp.getBody();
+ auto forOp = cast<scf::ForOp>(oldLoopOp);
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newOuts);
- newInnerMostLoop = newForOp;
+ newLoopOp = newForOp;
newLoopBody = newForOp.getBody();
} else {
- auto forallOp = cast<scf::ForallOp>(oldInnerMostLoop.getOperation());
- newOuts = llvm::to_vector(forallOp.getOutputs());
- newOuts.append(newInitAppend.begin(), newInitAppend.end());
- oldLoopBody = forallOp.getBody();
+ auto forallOp = cast<scf::ForallOp>(oldLoopOp);
auto newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newInnerMostLoop = newForallOp;
+ newLoopOp = newForallOp;
rewriter.eraseOp(newForallOp.getTerminator());
newLoopBody = newForallOp.getBody();
}
- // 5.b. Move the loop body to the new op.
+ // 4. Move the loop body to the new op.
unsigned oldNumArguments = oldLoopBody->getNumArguments();
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
newLoopBody->getArguments().take_front(oldNumArguments));
- // 5.c. Replace the result of old oldInnerMostLoop with newInnerMostLoop's
- // results.
- rewriter.replaceOp(oldInnerMostLoop,
- newInnerMostLoop->getResults().take_front(
- oldInnerMostLoop->getNumResults()));
- // 6. Set insertion point before terminator op of the loop and create a new
+ // 5. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
rewriter.setInsertionPoint(newForallOp.getTerminator());
+ clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
rewriter.setInsertionPoint(candidateSliceOp);
+ clonedInsertSliceOp =
+ cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
}
- // 7.a. Due to current assumption of `getTiledImplementation` that all
- // `Operands` are untiled with original tensor size, create dummy
- // `insertSliceOp` to align with that requirement.
- auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
- FailureOr<SmallVector<OpFoldResult>> realOffsets =
- computeRealOffsetsCoordinatedRootSliceOp(rewriter, loc, ossSliceOp,
- candidateSliceOpList);
- if (failed(realOffsets))
- return failure();
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, candidateSliceOp->getOperand(0),
- candidateSliceOpList.back()->getOperand(1), *realOffsets,
- ossSliceOp.getMixedSizes(), ossSliceOp.getMixedStrides());
-
- SmallVector<Value> newDpsInitsForConsumerDest = llvm::map_to_vector(
- newLoopBody->getArguments().drop_front(oldNumArguments),
- [](BlockArgument bArg) -> Value { return bArg; });
-
- // 7.b. If the outerMostLoop is scf.forall, then the `newExtractOps` has been
- // additionally created at `step 4.e` for `dpsInits`. As the counterpart, the
- // `insertSliceOp` is also needed for the same purpose with `step 7.a`.
- if (!newExtractOps.empty()) {
- for (auto &&[extractOp, newDpsInit] :
- llvm::zip_equal(newExtractOps, newDpsInitsForConsumerDest)) {
- auto alignDpsInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, newDpsInit, extractOp.getSource(), extractOp.getMixedOffsets(),
- extractOp.getMixedSizes(), extractOp.getMixedStrides());
- newDpsInit = alignDpsInsertSliceOp.getResult();
- }
- }
-
- // 8.a. Clone consumer op.
+ // 6.a. Clone consumer op.
+ auto newForOpBlockArgsForConsumerDest =
+ newLoopBody->getArguments().drop_front(oldNumArguments);
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newDpsInitsForConsumerDest));
+ rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
- // 8.b. Replace all uses of the loop result with the result of the cloned
+ // 6.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
});
- // 9. Perform tiling of the cloned consumer and replace the operand at
+ // 7 - Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
+ auto ossSliceOp =
+ cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
- rewriter,
- cast<OffsetSizeAndStrideOpInterface>(
- clonedInsertSliceOp.getOperation()),
- clonedConsumerOp->getOpOperand(operandNumber));
+ rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
if (failed(tileAndFuseResult)) {
return failure();
}
@@ -1786,100 +1522,75 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
clonedInsertSliceOp.getSource());
- // 10. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice/parallel_insert_slice later.
- if (failed(computeAllResultTileForOpGivenOperandSliceOp(
- rewriter, clonedConsumerOp, operandNumber, ossSliceOp,
- allResultOffsets, allResultSizes))) {
- return failure();
+ // 8 - Extract offset/sizes/strides required to create the
+ // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+ // 9. Check that all insert strides are 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "containingOp's result yield with stride");
+ }
+
+ // 10. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp, "can't get iter domain position from input position");
}
- if (!newExtractOps.empty()) {
- fixDpsInitsOfTiledConsumer(
- rewriter, tileAndFuseResult->tiledOps[0],
- newLoopBody->getArguments().drop_front(oldNumArguments),
- allResultOffsets, allResultSizes);
+ // 11. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+ if (failed(clonedConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ clonedConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
}
+ auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+ auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
if (isInsertSliceOp) {
- auto newForOp = cast<scf::ForOp>(newInnerMostLoop);
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
fixTerminatorSCFYield(
- rewriter, newForOp, tileAndFuseResult->tiledOps[0]->getResults(),
- allResultOffsets, allResultSizes,
- newForOp.getBody()->getArguments().take_back(newInitAppend.size()));
+ rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+ newForOp.getBody()->getArguments().drop_front(1 + initSize));
} else {
- auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
fixTerminatorSCFInParallel(
rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
- allResultOffsets, allResultSizes,
- newForallOp.getBody()->getArguments().take_back(newInitAppend.size()));
+ arrayRefOffsets, arrayRefSizes,
+ newForallOp.getBody()->getArguments().drop_front(rank + initSize));
}
- newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
-
- // 11.a. Reconstruct terminator of outer loop by inner loop.
- auto outerCandidateIter = candidateSliceOpList.rbegin();
- for (auto [outerLoop, innerLoop] :
- llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),
- MutableArrayRef(newOuterLoops).drop_front())) {
- // 11.b. Create insertSliceOp according outer candidateSliceOp
- if (outerCandidateIter != candidateSliceOpList.rend() &&
- outerCandidateIter->getOperation()
- ->getParentOfType<LoopLikeOpInterface>() == outerLoop) {
- if (auto forallOp = dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
- rewriter.setInsertionPoint(forallOp.getTerminator());
- } else {
- rewriter.setInsertionPointAfter(*outerCandidateIter);
- }
-
- if (failed(computeAllResultTileForOpGivenOperandSliceOp(
- rewriter, clonedConsumerOp, operandNumber, *outerCandidateIter,
- allResultOffsets, allResultSizes))) {
- return failure();
- }
-
- if (auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation())) {
- fixTerminatorSCFYield(
- rewriter, forOp,
- innerLoop->getResults().take_back(newInitAppend.size()),
- allResultOffsets, allResultSizes,
- forOp.getBody()->getArguments().take_back(newInitAppend.size()));
- } else if (auto forallOp =
- dyn_cast<scf::ForallOp>(outerLoop.getOperation())) {
- fixTerminatorSCFInParallel(
- rewriter, forallOp,
- innerLoop->getResults().take_back(newInitAppend.size()),
- allResultOffsets, allResultSizes,
- forallOp.getBody()->getArguments().take_back(newInitAppend.size()));
- }
- outerCandidateIter++;
- } else {
- // 11.c. Yield additional new appended results of innerLoop
- assert(isa<scf::ForOp>(outerLoop));
- auto forOp = cast<scf::ForOp>(outerLoop);
- auto outerLoopYield =
- cast<scf::YieldOp>(forOp.getBody()->getTerminator());
- SmallVector<Value> newYields =
- llvm::to_vector(outerLoopYield.getOperands());
- ValueRange additionalYields =
- innerLoop->getResults().take_back(newInitAppend.size());
- newYields.append(additionalYields.begin(), additionalYields.end());
- rewriter.setInsertionPoint(outerLoopYield);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
- }
+ // 12. Replace the result of scf loop and consumer op with new loop's results.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 12. Replace the result of consumer op with new outerMost loop's
- // results.
for (auto &&[oldResult, newResult] :
llvm::zip(consumerOp->getResults(),
- newOuterLoops.front()->getResults().take_back(
- newInitAppend.size()))) {
+ newLoopOp->getResults().drop_front(initSize))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the cloned consumer op.
+ // 13. Need to erase the old scf loop and the cloned consumer op.
+ rewriter.eraseOp(oldLoopOp);
rewriter.eraseOp(clonedConsumerOp);
return scf::SCFFuseConsumerOfSliceResult{
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d60d5f4fd7b3c..400b558e37fcd 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,99 +315,3 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
-
-// -----
-
-#map = affine_map<(d0) -> (d0 * 128)>
-module {
- func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %cst = arith.constant 0.000000e+00 : f32
- %dest0 = tensor.empty() : tensor<256x256xf32>
- %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
- %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
- %iv0 = affine.apply #map(%arg3)
- %iv1 = affine.apply #map(%arg4)
- %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
- %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
- %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
- %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
- %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
- %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
- %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
- %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
- %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
- %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
- scf.yield %insert_slice : tensor<128x128xf32>
- }
- scf.yield %3 : tensor<128x128xf32>
- }
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
- }
- }
- %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
- return %5 : tensor<256x256xf32>
- }
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 128)>
-// CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
-// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
-// CHECK: %[[dest1:.*]] = linalg.fill
-// CHECK-SAME: outs(%[[dest0]] :
-// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
-// CHECK-SAME: {
-// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
-// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
-// CHECK: %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
-// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
-// CHECK: %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
-// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
-// CHECK-SAME: {
-// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
-// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
-// CHECK-SAME: {
-// CHECK: %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
-// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
-// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
-// CHECK-SAME: outs(%[[MAT_OUT_SLICE1]] :
-// CHECK: %[[REAL_SECOND_IV1:.*]] = affine.apply #[[MAP1]](%[[IV3]], %[[IV1]])
-// CHECK: %[[REAL_SECOND_IV2:.*]] = affine.apply #[[MAP1]](%[[IV4]], %[[IV2]])
-// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[REAL_SECOND_IV1]], %[[REAL_SECOND_IV2]]] [64, 64] [1, 1]
-// CHECK: %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
-// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
-// CHECK-SAME: outs(%[[ADD_OUT_SLICE1]] :
-// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
-// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
-// CHECK: }
-// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
-// CHECK: }
-// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
-// CHECK: }
-// CHECK: }
-// CHECK: return %[[FINAL_RESULT]]#1 :
>From a18d528715b45abb99c5b67df76c512f159ee7f5 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 25 Jun 2024 01:58:59 -0700
Subject: [PATCH 3/3] extend consumer fuse to nested scf loop (v2)
---
.../SCF/Transforms/TileUsingInterface.h | 4 +
.../SCF/Transforms/TileUsingInterface.cpp | 271 ++++++++++++++++--
.../tile-and-fuse-consumer.mlir | 98 ++++++-
3 files changed, 350 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d68ca11207376..1c9a1ff8c64f5 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -254,6 +254,10 @@ struct SCFFuseConsumerOfSliceResult {
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
SmallVector<Operation *> tiledOps;
};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
+ Operation *candidateSliceOp);
+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..952ddff64a24e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1228,12 +1228,24 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
/// failure otherwise.
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
Block *containingOpBlock) {
- // Step 1. Check that the value has exactly one use.
- if (!llvm::hasSingleElement(val.getUses()))
+ // Step 1. Check that the value has exactly one use, ignoring insert-slice users.
+ OpOperand *operand = nullptr;
+ for (auto &use : val.getUses()) {
+ Operation *user = use.getOwner();
+ if (isa<tensor::InsertSliceOp>(user) ||
+ isa<tensor::ParallelInsertSliceOp>(user))
+ continue;
+ else {
+ if (operand)
+ return failure();
+ else
+ operand = &use;
+ }
+ }
+ if (!operand)
return failure();
// Step 2. Get uses.
- OpOperand &operand = (*val.getUses().begin());
- Operation *consumerOp = operand.getOwner();
+ Operation *consumerOp = operand->getOwner();
// TODO: We have to init result of consumer before scf.for, use
// DestinationStyleOpInterface to get result shape from init for now.
// Add support for other op such as op has InferTypeOpInterface.
@@ -1242,7 +1254,56 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
return failure();
if (containingOpBlock != consumerOp->getBlock())
return failure();
- return &operand;
+ return operand;
+}
+
+/// Collect `loop` together with its directly enclosing loops, walking outward
+/// for as long as the predicate succeeds; the result is ordered outer to inner.
+///
+/// @param loop: target loop; it is always included in the result, so if no
+/// enclosing loop qualifies the result contains only `loop` itself.
+/// @param pred: predicate evaluated on each enclosing loop; the walk stops at
+/// the first loop for which it fails (or at a parent that is not a loop).
+/// @return the collected loops, outermost first and `loop` last.
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+///   %1 = scf.for()
+///     %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given with a predicate that always succeeds, this
+/// function returns the three nested loops in the order %0, %1, %2.
+static SmallVector<LoopLikeOpInterface>
+getOuterNestLoopsWhile(LoopLikeOpInterface loop,
+ std::function<LogicalResult(LoopLikeOpInterface)> pred) {
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+ while (outerLoop && succeeded(pred(outerLoop))) {
+ nestLoops.push_back(outerLoop);
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+ }
+ // The walk above collected inner to outer; reverse it to outer to inner.
+ return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Check whether `loop` is an scf.for whose first inner loop is its second-to-last op.
+static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+ for (auto &&[index, op] :
+ llvm::enumerate(forOp.getBody()->getOperations())) {
+ // Only the first loop-like op is inspected: it must be the second-to-last op of the
+ // body. NOTE(review): presumably the last op then yields its results; not verified here.
+ if (isa<LoopLikeOpInterface>(op)) {
+ return (index + 2) == forOp.getBody()->getOperations().size()
+ ? success()
+ : failure();
+ }
+ }
+ }
+ return failure();
 }
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1262,9 +1323,11 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
auto forOp = dyn_cast<scf::ForOp>(containingOp);
if (!forOp)
return failure();
- Value resultingValue = forOp->getResult(resultNumber);
+ LoopLikeOpInterface topLevelForOp =
+ getOuterNestLoopsWhile(forOp, isForOpYieldResultOfInnerLoop).front();
+ Value resultingValue = topLevelForOp->getResult(resultNumber);
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
+ return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
}
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1383,8 +1446,8 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
- Operation *candidateSliceOp) {
+mlir::scf::tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();
@@ -1418,22 +1481,27 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (isInsertSliceOp) {
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
initSize = forOp.getInits().size();
} else {
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
initSize = forallOp.getOutputs().size();
rank = forallOp.getRank();
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ Operation *oldTopLevelLoop = oldLoopOp;
+ SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
+ if (isInsertSliceOp) {
+ oldNestedForOps =
+ getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
+ isForOpYieldResultOfInnerLoop);
+ oldTopLevelLoop = oldNestedForOps.front();
+ }
+ if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ oldTopLevelLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
@@ -1442,21 +1510,51 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
+ SmallVector<Value> newInitAppend = dpsInits;
Location loc = oldLoopOp->getLoc();
// 3. Create new scf loop op.
rewriter.setInsertionPoint(consumerOp);
+
+ // 3.a Create new outer scf loops if necessary
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
+ if (isNestedForOps) {
+ for (auto &&[index, loopOp] :
+ llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) {
+ auto forOp = cast<scf::ForOp>(loopOp);
+ SmallVector<Value> newInits;
+ newInits = llvm::to_vector(forOp.getInits());
+ newInits.append(newInitAppend.begin(), newInitAppend.end());
+ auto newLoop = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newInits);
+ newInitAppend = llvm::map_to_vector(
+ newLoop.getRegionIterArgs().take_back(newInitAppend.size()),
+ [](BlockArgument bArg) -> Value { return bArg; });
+ rewriter.mergeBlocks(
+ forOp.getBody(), newLoop.getBody(),
+ newLoop.getBody()->getArguments().take_front(initSize + 1));
+ rewriter.replaceOp(
+ forOp, newLoop->getResults().take_front(forOp->getNumResults()));
+ newNestedForOps.push_back(newLoop);
+ rewriter.setInsertionPointAfter(oldNestedForOps[index + 1]);
+ }
+ }
+
+ // 3.b Create new innermost scf loop
Operation *newLoopOp = nullptr;
Block *newLoopBody = nullptr;
if (isInsertSliceOp) {
auto forOp = cast<scf::ForOp>(oldLoopOp);
+ llvm::append_range(newOuts, forOp.getInits());
+ newOuts.append(newInitAppend);
+ oldLoopBody = forOp.getBody();
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(), newOuts);
@@ -1464,6 +1562,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
newLoopBody = newForOp.getBody();
} else {
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+ llvm::append_range(newOuts, forallOp.getOutputs());
+ newOuts.append(newInitAppend);
+ oldLoopBody = forallOp.getBody();
auto newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
@@ -1577,28 +1678,156 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
}
- // 12. Replace the result of scf loop and consumer op with new loop's results.
+ // 12. Restore outer loops from inner to outer
+ if (isNestedForOps) {
+ newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+ MutableArrayRef(newNestedForOps).drop_front())) {
+ auto forOp = cast<scf::ForOp>(outerLoop);
+ auto outerLoopYield =
+ cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ SmallVector<Value> newYields =
+ llvm::to_vector(outerLoopYield.getOperands());
+ ValueRange additionalYields =
+ innerLoop->getResults().take_back(newInitAppend.size());
+ newYields.append(additionalYields.begin(), additionalYields.end());
+ rewriter.setInsertionPoint(outerLoopYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+ }
+ }
+
+ // 13. Replace the result of scf loop and consumer op with new loop's results.
for (auto &&[oldResult, newResult] :
llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
+ Operation *newTopLevelLoop =
+ isNestedForOps ? newNestedForOps.front() : newLoopOp;
for (auto &&[oldResult, newResult] :
llvm::zip(consumerOp->getResults(),
- newLoopOp->getResults().drop_front(initSize))) {
+ newTopLevelLoop->getResults().drop_front(initSize))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the old scf loop and the cloned consumer op.
+ // 14. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(oldLoopOp);
rewriter.eraseOp(clonedConsumerOp);
+ // 15. Erase the cloned insertSliceOp and the unused extractSliceOp to
+ // avoid complex dominance analysis.
+ assert(clonedInsertSliceOp->hasOneUse());
+ auto unUsedExtractOp =
+ cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+ rewriter.eraseOp(unUsedExtractOp);
+ rewriter.eraseOp(clonedInsertSliceOp);
+
return scf::SCFFuseConsumerOfSliceResult{
consumerOpOperand,
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
tileAndFuseResult->tiledOps};
}
+/// Starting at `targetSliceOp`, follow scf.yield / slice-op uses outwards and
+/// return the result of the outermost loop reached this way. E.g.
+/// ```
+/// %1 = scf.for
+///   %2 = scf.for
+///     %3 = scf.for
+///       ...
+///       %4 = insert
+///       yield %4
+///     %5 = insert %3
+///     yield %5
+///   yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @param insertSliceOpChain: every slice op visited (here %4, then %5)
+/// @param curDepth, maxDepth: recursion guard; fails once curDepth > maxDepth
+/// @return resultValue: %1, or failure if no enclosing loop result is found
+static FailureOr<Value> getResultOfTopLevelLoopYieldInsertSliceOp(
+ Operation *targetSliceOp,
+ SmallVectorImpl<OffsetSizeAndStrideOpInterface> &insertSliceOpChain,
+ int curDepth = 0, int maxDepth = 5) {
+ assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ // Bound the recursion depth to avoid a stack overflow.
+ if (curDepth > maxDepth)
+ return failure();
+
+ insertSliceOpChain.push_back(
+ cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue); // NOTE(review): assumes dest is a block argument - confirm
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
+ return failure();
+
+ bool traverseUpperLoop;
+ do {
+ traverseUpperLoop = false;
+ for (OpOperand &useOperand : resultOfLoop.getUses()) {
+ if (auto sliceOp =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+ return getResultOfTopLevelLoopYieldInsertSliceOp(
+ sliceOp, insertSliceOpChain, curDepth + 1); // NOTE(review): maxDepth is not forwarded
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ // Step out to the enclosing loop and rescan the uses of its result.
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ traverseUpperLoop = true;
+ break;
+ }
+ }
+ } while (traverseUpperLoop);
+ return resultOfLoop;
+}
+
+/// Fuse the consumer of `candidateSliceOp`, even across nested loops, by
+/// applying `tileAndFuseConsumerOfSliceImpl` to each slice op in its chain.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ SmallVector<OffsetSizeAndStrideOpInterface> sliceOpChain;
+ if (failed(getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp,
+ sliceOpChain)))
+ return failure();
+
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
+ // The chain was collected inner-to-outer; process it outer-to-inner.
+ std::reverse(sliceOpChain.begin(), sliceOpChain.end());
+ // Fuse at every level; the result of the last (innermost) step is returned.
+ for (auto &sliceOp : sliceOpChain) {
+ fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
+ if (failed(fuseConsumerResult)) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "could not fuse consumer of sliceOp");
+ }
+ }
+ return fuseConsumerResult;
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcd..8022d5af37078 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -45,12 +45,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
-// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
// CHECK: }
@@ -170,13 +170,13 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
-// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
@@ -315,3 +315,97 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+ func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+ %iv0 = affine.apply #map(%arg3)
+ %iv1 = affine.apply #map(%arg4)
+ %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+ %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+ %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+ %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+ %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+ %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+ %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+ scf.yield %insert_slice : tensor<128x128xf32>
+ }
+ scf.yield %3 : tensor<128x128xf32>
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+ }
+ }
+ %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %5 : tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[dest1:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[dest0]] :
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+// CHECK: %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+// CHECK: %[[ADD_OPERAND2_SLICE0:.*]] = tensor.extract_slice %[[ARG2]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
+// CHECK-SAME: {
+// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE1]] :
+// CHECK: %[[ADD_OPERAND2_SLICE1:.*]] = tensor.extract_slice %[[ADD_OPERAND2_SLICE0]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE1]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE1]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+// CHECK: }
+// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+// CHECK: }
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
\ No newline at end of file
More information about the Mlir-commits
mailing list