[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 21 18:55:37 PDT 2023
https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/67083
Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` operations, this
method tiles such operations using `scf.forall`. Most of this
implementation is derived from the `linalg::tileToForallOp` method.
Eventually that method will either be deprecated or changed to use
the method introduced here.
>From 5f668bb3f4b1305433416edd9db89b1b0ebd2465 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 21 Sep 2023 16:22:32 -0700
Subject: [PATCH 1/2] [mlir][TilingInterface] NFC code changes separated out
from the introduction of `scf::tileUsingSCFForallOp`.
This patch contains NFC changes that are a precursor to the introduction
of the `scf::tileUsingSCFForallOp` method.
---
.../SCF/Transforms/TileUsingInterface.h | 4 +-
.../TransformOps/LinalgTransformOps.cpp | 14 +--
.../SCF/Transforms/TileUsingInterface.cpp | 112 +++++++++++-------
.../TilingInterface/TestTilingInterface.cpp | 99 ++++++++--------
4 files changed, 125 insertions(+), 104 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index ca641c596c7b7bb..9f49d97e141e0c8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -60,7 +60,7 @@ struct SCFTilingResult {
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
- SmallVector<scf::ForOp> loops;
+ SmallVector<Operation *> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
@@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
/// generated operation.
llvm::SetVector<Operation *> tiledAndFusedOps;
/// The `scf.for` operations that iterate over the tiles.
- SmallVector<scf::ForOp> loops;
+ SmallVector<Operation *> loops;
/// The replacement values to use for the tiled and fused operations.
llvm::DenseMap<Value, Value> replacements;
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1819ca614a060fd..ca3db7401e38caa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
SmallVector<Operation *> opsToReplace{target};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
- SmallVector<Value> replacements;
- replacements.reserve(toReplace->getNumResults());
- for (OpResult res : toReplace->getResults()) {
- auto it = tiledResults->replacements.find(res);
- if (it == tiledResults->replacements.end())
- replacements.push_back(res);
- else
- replacements.push_back(it->getSecond());
+ for (OpResult res : toReplace->getResults())
+ if (auto replacement = tiledResults->replacements.lookup(res))
+ rewriter.replaceAllUsesWith(res, replacement);
+ if (toReplace->use_empty()) {
+ rewriter.eraseOp(toReplace);
}
- rewriter.replaceOp(toReplace, replacements);
}
// Report back the relevant handles to the transform op.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6cfba3fef15ebda..6bde60ad757a73b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
return filledVector;
}
+/// Converts a list of typed ops (`SrcOpTy`) into a list of `Operation *`.
+template <typename SrcOpTy>
+static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
+  SmallVector<Operation *> operations;
+  operations.reserve(ops.size());
+  for (SrcOpTy typedOp : ops)
+    operations.push_back(static_cast<Operation *>(typedOp));
+  return operations;
+}
+/// Overload so a `SmallVector` can be passed directly: template argument
+/// deduction does not consider the implicit conversion to `ArrayRef`.
+template <typename SrcOpTy>
+static SmallVector<Operation *>
+getAsOperations(const SmallVector<SrcOpTy> &ops) {
+  return getAsOperations(ArrayRef<SrcOpTy>(ops));
+}
+
+/// Casts every op in `ops` to `DstOpTy`. Every element must be a `DstOpTy`
+/// (`cast` asserts otherwise).
+template <typename DstOpTy>
+static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
+  SmallVector<DstOpTy> typedOps;
+  typedOps.reserve(ops.size());
+  for (Operation *untypedOp : ops)
+    typedOps.push_back(cast<DstOpTy>(untypedOp));
+  return typedOps;
+}
+/// Overload accepting a `SmallVector<Operation *>` directly.
+template <typename DstOpTy>
+static SmallVector<DstOpTy>
+castToTypedOperations(const SmallVector<Operation *> &ops) {
+  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
+}
+
//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
- Value tileSize) {
- std::optional<int64_t> ts = getConstantIntValue(tileSize);
- if (ts && ts.value() == 1)
- return getAsOpFoldResult(tileSize);
+ OpFoldResult tileSize) {
+ if (isConstantIntValue(tileSize, 1))
+ return tileSize;
if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -295,8 +318,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}
- scf::SCFTilingResult tilingResult;
SmallVector<OpFoldResult> offsets, sizes;
+ SmallVector<scf::ForOp> forLoops;
{
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -319,8 +342,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
- tilingResult.loops = generateTileLoopNest(
- rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+ forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
+ tileSizeVector, offsets, sizes);
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -330,30 +353,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
+ if (!forLoops.empty()) {
llvm::dbgs() << "LoopNest shell :\n";
- tilingResult.loops.front().dump();
+ forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
// 4. Generate the tiled implementation within the inner most loop.
- if (!tilingResult.loops.empty())
- rewriter.setInsertionPoint(
- tilingResult.loops.back().getBody()->getTerminator());
+ if (!forLoops.empty())
+ rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
- tilingResult.tiledOps.append(tiledImplementation->tiledOps);
+
if (op->getNumResults() == 0) {
- // nothing more to do.
- return tilingResult;
+ return scf::SCFTilingResult{
+ tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
}
// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
- if (tilingResult.loops.empty()) {
- tilingResult.replacements = tiledImplementation->tiledValues;
- return tilingResult;
+ if (forLoops.empty()) {
+ return scf::SCFTilingResult{tiledImplementation->tiledOps,
+ getAsOperations(forLoops),
+ tiledImplementation->tiledValues};
}
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -377,18 +400,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");
- tilingResult.replacements = yieldTiledValues(
+ SmallVector<Value> replacements = yieldTiledValues(
rewriter, destinationTensors, tiledImplementation.value(),
- resultOffsetsList, resultSizesList, tilingResult.loops);
-
+ resultOffsetsList, resultSizesList, forLoops);
LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
+ if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
- tilingResult.loops.front().dump();
+ forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
- return tilingResult;
+ return scf::SCFTilingResult{tiledImplementation->tiledOps,
+ getAsOperations(forLoops), replacements};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -466,6 +489,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
results.mergeOp = mergeOp;
return results;
}
+
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -636,7 +660,9 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}
// 1. First tile the consumer.
- scf::SCFTileAndFuseResult tileAndFuseResult;
+ SmallVector<scf::ForOp> forLoops;
+ SetVector<Operation *> fusedProducers, tiledAndFusedOps;
+ DenseMap<Value, Value> replacements;
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
@@ -644,20 +670,21 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
- tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
- tileAndFuseResult.loops = std::move(tilingResult->loops);
- for (const auto &result : llvm::enumerate(
- llvm::zip(consumer->getResults(), tilingResult->replacements))) {
- tileAndFuseResult.replacements[std::get<0>(result.value())] =
- std::get<1>(result.value());
+ tiledAndFusedOps.insert(tiledOp);
+ forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
+ for (auto [index, origValue, replacement] :
+ llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
+ replacements[origValue] = replacement;
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
- result.index())] = result.index();
+ index)] = index;
}
}
// If there are no loops generated, fusion is immaterial.
- if (tileAndFuseResult.loops.empty())
- return tileAndFuseResult;
+ if (forLoops.empty()) {
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+ getAsOperations(forLoops), replacements};
+ }
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
@@ -674,7 +701,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
};
std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
@@ -684,19 +711,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
- tileAndFuseResult.loops);
- if (!fusedProducer)
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
+ tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+ if (!fusedResult)
continue;
if (Operation *tiledAndFusedOp =
- fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
- tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+ fusedResult->tiledAndFusedProducer.getDefiningOp()) {
+ fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
+ tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
}
}
- return tileAndFuseResult;
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+ getAsOperations(forLoops), replacements};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2fcc7bcadb60450..5e831c2c32562fb 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -37,51 +37,51 @@ using namespace mlir;
namespace {
/// Marker used as attribute name in generated Linalg rewriting transformations.
-const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__";
+const StringLiteral kTransformMarker = "__internal_linalg_transform__";
/// Helper class to control application of linalg transformation patterns.
/// Control comes in 2 forms:
/// 1. attribute matching and setting behavior using the attribute named
-/// `kLinalgTransformMarker`. This can be used to build a state machine
+/// `kTransformMarker`. This can be used to build a state machine
/// using attributes and incrementally applying patterns to advance states.
/// 2. filter function, which is a simple lambda on the Operation* that
/// returns a LogicalResult.
-struct LinalgTransformationFilter {
+struct TransformationFilter {
using FilterFunction = std::function<LogicalResult(Operation *)>;
- explicit LinalgTransformationFilter(
+ explicit TransformationFilter(
ArrayRef<StringAttr> matchDisjunction = {},
std::optional<StringAttr> replacement = std::nullopt);
- explicit LinalgTransformationFilter(
+ explicit TransformationFilter(
const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
std::optional<StringAttr> replacement = std::nullopt);
- LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
- LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
+ TransformationFilter(TransformationFilter &&) = default;
+ TransformationFilter(const TransformationFilter &) = default;
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
- void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
- Operation *op) const;
+ void replaceTransformationFilter(PatternRewriter &rewriter,
+ Operation *op) const;
- LinalgTransformationFilter &addFilter(const FilterFunction &f) {
+ TransformationFilter &addFilter(const FilterFunction &f) {
if (f)
filters.push_back(f);
return *this;
}
template <typename... OpTypes>
- LinalgTransformationFilter &addOpFilter() {
+ TransformationFilter &addOpFilter() {
return addFilter(
[](Operation *op) { return success(isa<OpTypes...>(op)); });
}
- LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+ TransformationFilter &addOpNameFilter(StringRef opName) {
return addFilter([opName](Operation *op) {
return success(op->getName().getStringRef() == opName);
});
}
- LinalgTransformationFilter &setMatchByDefault() {
+ TransformationFilter &setMatchByDefault() {
matchByDefault = true;
return *this;
}
@@ -95,20 +95,19 @@ struct LinalgTransformationFilter {
bool matchByDefault;
};
-LinalgTransformationFilter::LinalgTransformationFilter(
+TransformationFilter::TransformationFilter(
ArrayRef<StringAttr> matchDisjunction,
std::optional<StringAttr> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement), matchByDefault(false) {}
-LogicalResult
-LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
- Operation *op) const {
+LogicalResult TransformationFilter::checkAndNotify(PatternRewriter &rewriter,
+ Operation *op) const {
if (llvm::any_of(filters,
[&](const FilterFunction &f) { return failed(f(op)); }))
return failure();
- auto attr = op->template getAttrOfType<StringAttr>(kLinalgTransformMarker);
+ auto attr = op->template getAttrOfType<StringAttr>(kTransformMarker);
if (!attr) {
// 1. Has no filter case and matchDisjunction is empty.
@@ -134,12 +133,12 @@ LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
});
}
-void LinalgTransformationFilter::replaceLinalgTransformationFilter(
+void TransformationFilter::replaceTransformationFilter(
PatternRewriter &rewriter, Operation *op) const {
if (replacement.has_value())
- op->setAttr(kLinalgTransformMarker, *replacement);
+ op->setAttr(kTransformMarker, *replacement);
else
- op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
+ op->removeAttr(rewriter.getStringAttr(kTransformMarker));
}
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
@@ -147,18 +146,17 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter(
/// using a `filter` to avoid recursive application.
struct TestTileUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {
- TestTileUsingSCFForOp(
- MLIRContext *context, scf::SCFTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
+ TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
/// Construct a generic pattern applied to `opName`.
- TestTileUsingSCFForOp(
- StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
+ TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -179,13 +177,13 @@ struct TestTileUsingSCFForOp
}
for (auto *tiledOp : tilingResult->tiledOps)
- filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
+ filter.replaceTransformationFilter(rewriter, tiledOp);
return success();
}
private:
scf::SCFTilingOptions options;
- LinalgTransformationFilter filter;
+ TransformationFilter filter;
};
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
@@ -196,7 +194,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {
TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
MLIRContext *context, scf::SCFTileAndFuseOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ TransformationFilter filter = TransformationFilter(),
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -205,7 +203,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
StringRef opName, MLIRContext *context,
scf::SCFTileAndFuseOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ TransformationFilter filter = TransformationFilter(),
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -229,14 +227,14 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
}
rewriter.replaceOp(op, replacements);
- filter.replaceLinalgTransformationFilter(
+ filter.replaceTransformationFilter(
rewriter, tileAndFuseResult->tiledAndFusedOps.front());
return success();
}
private:
scf::SCFTileAndFuseOptions options;
- LinalgTransformationFilter filter;
+ TransformationFilter filter;
};
/// Pattern to tile a consumer and fuse producer with it
@@ -254,7 +252,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
TestTileConsumerFuseAndYieldProducerUsingSCFForOp(
MLIRContext *context, scf::SCFTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ TransformationFilter filter = TransformationFilter(),
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -302,6 +300,8 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tilingResult->tiledOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
+ auto forLoops = llvm::to_vector(llvm::map_range(
+ tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
@@ -309,8 +309,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
- tilingResult->loops);
+ tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedProducer)
continue;
@@ -318,11 +317,10 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
- return !isIgnoredUser(user, tilingResult->loops.front());
+ return !isIgnoredUser(user, forLoops.front());
})) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
- fusedProducer.value(),
- tilingResult->loops);
+ fusedProducer.value(), forLoops);
yieldedValuesToOrigValues.push_back(untiledProducer);
}
@@ -332,7 +330,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
addCandidateSlices(fusedProducerOp, candidates);
}
- scf::ForOp outermostLoop = tilingResult->loops.front();
+ scf::ForOp outermostLoop = forLoops.front();
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
@@ -340,8 +338,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
});
}
rewriter.eraseOp(rootOp);
- filter.replaceLinalgTransformationFilter(rewriter,
- tilingResult->tiledOps.back());
+ filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
return success();
}
@@ -370,7 +367,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
}
scf::SCFTilingOptions options;
- LinalgTransformationFilter filter;
+ TransformationFilter filter;
};
/// Pattern to lower operations that implement the `TilingInterface` to
@@ -453,8 +450,8 @@ static void addPatternForTiling(MLIRContext *context,
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
- LinalgTransformationFilter filter(StringAttr::get(context, filterName),
- StringAttr::get(context, "tiled"));
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
@@ -467,8 +464,8 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
- LinalgTransformationFilter filter(StringAttr::get(context, filterName),
- StringAttr::get(context, "tiled"));
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
context, tilingOptions, filter);
}
@@ -483,8 +480,8 @@ static void addPatternForTileAndFuse(MLIRContext *context,
getAsIndexOpFoldResult(context, tileSizes);
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
.setInterchange(interchange);
- LinalgTransformationFilter filter(StringAttr::get(context, filterName),
- StringAttr::get(context, "tiled"));
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
context, tileAndFuseOptions, filter);
}
>From 7ec04bc656697e6a85727e82bd75523b3194f497 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 21 Sep 2023 16:24:11 -0700
Subject: [PATCH 2/2] [mlir][TilingInterface] Add `scf::tileUsingSCFForallOp`
method to tile using the interface to generate `scf::forall`.
Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` operations, this
method tiles such operations using `scf.forall`. Most of this
implementation is derived from the `linalg::tileToForallOp` method.
Eventually that method will either be deprecated or changed to use
the method introduced here.
---
.../SCF/Transforms/TileUsingInterface.h | 17 +++
.../SCF/Transforms/TileUsingInterface.cpp | 133 ++++++++++++++++++
.../TilingInterface/tile-using-scfforall.mlir | 37 +++++
.../TilingInterface/TestTilingInterface.cpp | 69 +++++++++
4 files changed, 256 insertions(+)
create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f49d97e141e0c8..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
interchangeVector = llvm::to_vector(interchange);
return *this;
}
+
+  /// Specify the mapping of loops to devices. This is only respected by loop
+  /// constructs that support such a mapping (like `scf.forall`), and is
+  /// ignored by loop constructs that don't (like `scf.for`). The attributes
+  /// are stored type-erased as plain `Attribute`s.
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+    mappingVector = llvm::to_vector(
+        llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+    return *this;
+  }
};
/// Transformation information returned after tiling.
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
}
};
+/// Tile an op that implements the `TilingInterface` using a single
+/// `scf.forall` op. Dimensions whose tile size is 0 are not tiled, and
+/// `options.mappingVector`, when non-empty, is attached as the loop mapping.
+/// On success, `SCFTilingResult::loops` holds the generated `scf.forall` and
+/// `replacements` holds its results, which replace the results of `op`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                     const SCFTilingOptions &options);
+
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6bde60ad757a73b..9054f7bcdde7e15 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -121,6 +121,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}
+/// Clones `op` at the rewriter's current insertion point. If the clone
+/// implements the `DestinationStyleOpInterface`, its init (destination)
+/// operands are replaced by `newDestArgs`; otherwise the clone keeps the
+/// original operands. Returns the cloned operation.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+                                                  Operation *op,
+                                                  ValueRange newDestArgs) {
+  Operation *clonedOp = rewriter.clone(*op);
+  auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
+  if (!destinationStyleOp)
+    return clonedOp;
+  // The inits occupy the operand range [start, end); overwrite exactly that
+  // range with the new destinations.
+  auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+  // Cast avoids a signed/unsigned comparison between the init count and size.
+  assert(static_cast<size_t>(end - start) == newDestArgs.size() &&
+         "expected as many new destination args as number of inits of the "
+         "operation");
+  clonedOp->setOperands(start, end - start, newDestArgs);
+  return clonedOp;
+}
+
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -727,6 +745,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
getAsOperations(forLoops), replacements};
}
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+/// Tiles `op` using a single `scf.forall` whose `shared_outs` are the
+/// destination tensors of `op`. Each dimension with a non-zero tile size
+/// becomes one loop of the `scf.forall`; dimensions with tile size 0 are left
+/// untiled. Tiled results are written back with `tensor.parallel_insert_slice`
+/// ops in the `scf.forall` terminator. On failure, the partially built loop is
+/// erased before returning.
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation. Only
+  //    unit-stride iteration domains are supported.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasNonUnitStride = [](Range r) {
+    return !isConstantIntValue(r.stride, 1);
+  };
+  if (llvm::any_of(loopRanges, hasNonUnitStride))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+  if (!options.tileSizeComputationFunction)
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the lower bounds, upper bounds and steps of the materialized
+  //    loops; untiled dimensions get no loop. NOTE(review): `loopRange.size`
+  //    is used directly as the upper bound, which presumes `size` denotes the
+  //    upper bound (e.g. offset 0) -- same convention as `getBoundedTileSize`.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [tileSize, loopRange] :
+       llvm::zip_equal(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 5. Build the device mapping attribute, if any. Each materialized loop
+  //    needs exactly one mapping entry, so reject a mismatched mapping before
+  //    creating the loop.
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
+    if (options.mappingVector.size() != lbs.size())
+      return op->emitOpError(
+          "expected as many mapping attributes as tiled loops");
+    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+  }
+
+  // 6. Create the ForallOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
+  auto forallOp =
+      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+  // 7. Get the tile offsets and sizes. Untiled dimensions keep their full
+  //    range; tiled dimensions start at the induction variable and have a
+  //    size bounded by the remaining iteration space.
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(loopRanges.size());
+  tiledSizes.reserve(loopRanges.size());
+  ValueRange ivs = forallOp.getInductionVars();
+  {
+    int materializedLoopNum = 0;
+    for (auto [tileSize, loopRange] :
+         llvm::zip_equal(tileSizeVector, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0)) {
+        tiledOffsets.push_back(loopRange.offset);
+        tiledSizes.push_back(loopRange.size);
+        continue;
+      }
+      Value iv = ivs[materializedLoopNum++];
+      tiledOffsets.push_back(iv);
+      tiledSizes.push_back(
+          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+    }
+  }
+
+  // 8. Tile the operation. The op is cloned with its destinations rewired to
+  //    the `shared_outs` block arguments, so the tiled implementation slices
+  //    those instead of the original destinations.
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *clonedOp =
+      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+  FailureOr<TilingResult> tilingResult =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(
+          rewriter, tiledOffsets, tiledSizes);
+  // The clone only serves as a template for tiling; erase it on both the
+  // success and the failure path so it never leaks into the IR.
+  rewriter.eraseOp(clonedOp);
+  if (failed(tilingResult)) {
+    // The loop has no users yet; drop it along with any partial body.
+    rewriter.eraseOp(forallOp);
+    return op->emitOpError("failed to tile operation");
+  }
+
+  // 9. Parallel insert back into the result tensor.
+  for (auto [index, tiledValue, destBBArg] :
+       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+    // 9.a. Partial subset information is inserted just before the terminator.
+    rewriter.setInsertionPoint(forallOp.getTerminator());
+
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+                                        tiledSizes, resultOffsets,
+                                        resultSizes))) {
+      // The loop has no users yet; drop it along with its partial body.
+      rewriter.eraseOp(forallOp);
+      return op->emitOpError("output offsets couldn't be calculated");
+    }
+    SmallVector<OpFoldResult> strides(resultSizes.size(),
+                                      rewriter.getIndexAttr(1));
+
+    // 9.b. Parallel insertions are inserted at the end of the combining
+    // terminator.
+    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+  }
+
+  // 10. Return the tiling result: the tiled ops, the single `scf.forall` as
+  //     the loop nest, and the forall results as replacements for `op`.
+  return scf::SCFTilingResult{
+      tilingResult->tiledOps,
+      {forallOp.getOperation()},
+      llvm::to_vector(llvm::map_range(forallOp.getResults(),
+                                      [](auto val) -> Value { return val; }))};
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..f40374b7b5485da
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+// Tiles the "simple_gemm" matmul by (10, 20) using scf.forall. The tiled op
+// must be re-marked "tiled" so that the pattern does not tile it again.
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK: func.func @simple_matmul(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
+// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME: __internal_linalg_transform__ = "tiled"
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 5e831c2c32562fb..cad6955dece65ba 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp
TransformationFilter filter;
};
+/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using
+/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles)
+/// while using a `filter` to avoid recursive application.
+struct TestTileUsingSCFForallOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ /// NOTE: `opName` is ignored; this also matches any `TilingInterface` op.
+ TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ scf::tileUsingSCFForallOp(rewriter, op, options);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+ if (op->getNumResults()) {
+ rewriter.replaceOp(op, tilingResult->replacements);
+ } else {
+ rewriter.eraseOp(op);
+ }
+
+ for (auto *tiledOp : tilingResult->tiledOps)
+ filter.replaceTransformationFilter(rewriter, tiledOp);
+ return success();
+ }
+
+private:
+ scf::SCFTilingOptions options;
+ TransformationFilter filter;
+};
+
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -415,6 +460,12 @@ struct TestTilingInterfacePass
"Test tiling using TilingInterface with scf.for operations"),
llvm::cl::init(false)};
+ Option<bool> testTilingForall{
+ *this, "tile-using-scf-forall",
+ llvm::cl::desc(
+ "Test tiling using TilingInterface with scf.forall operations"),
+ llvm::cl::init(false)};
+
Option<bool> testTileConsumerFuseAndYieldProducer{
*this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
llvm::cl::desc(
@@ -455,6 +506,23 @@ static void addPatternForTiling(MLIRContext *context,
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
+/// Registers a `TestTileUsingSCFForallOp` pattern in `patterns` that tiles ops
+/// matching `filterName` with `tileSizes` (and optional `interchange`) using
+/// `scf.forall`, re-marking the resulting tiled ops as "tiled".
+static void addPatternForTilingUsingForall(MLIRContext *context,
+ RewritePatternSet &patterns,
+ StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> interchange = {}) {
+ scf::SCFTilingOptions tilingOptions;
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
+ patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
+}
+
static void addPatternForTileFuseAndYield(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
@@ -514,6 +582,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
return;
}
+ if (testTilingForall) {
+ addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+ return;
+ }
if (testTileConsumerAndFuseProducer) {
// 1. Tile and fuse of gemm with fill producer and bias-add consumer.
addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});
More information about the Mlir-commits
mailing list