[Mlir-commits] [mlir] 4a66160 - [mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling patterns.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jan 6 13:27:40 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-06T16:27:35-05:00
New Revision: 4a661602ef2db22272cbb39bdb179996dbfa54b1
URL: https://github.com/llvm/llvm-project/commit/4a661602ef2db22272cbb39bdb179996dbfa54b1
DIFF: https://github.com/llvm/llvm-project/commit/4a661602ef2db22272cbb39bdb179996dbfa54b1.diff
LOG: [mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling patterns.
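The separate tiling patterns (LinalgBaseTilingPattern, LinalgTilingPattern<OpTy> and LinalgGenericTilingPattern) are reduced to a single interface-based LinalgTilingPattern (an OpInterfaceRewritePattern<LinalgOp>); tileLinalgOp now takes a RewriterBase instead of an OpBuilder.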
Differential Revision: https://reviews.llvm.org/D116733
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7592094410632..4b55caed849d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -169,9 +169,14 @@ struct TiledLinalgOp {
SmallVector<Operation *, 8> loops;
SmallVector<Value, 4> tensorResults;
};
-FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
+FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
const LinalgTilingOptions &options);
+/// Peel the loops of a TiledLinalgOp.
+void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
+ ArrayRef<int64_t> peeledLoops,
+ LinalgTilingLoopType loopType);
+
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
@@ -594,24 +599,35 @@ struct LinalgTilingOptions {
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
-/// Base pattern that applies the tiling transformation specified by `options`.
-/// Abort and return failure in 2 cases:
-/// 1. if the tiling specification is invalid and tiling fails to occur.
-/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
-/// and some operand shape cannot be bounded statically.
-struct LinalgBaseTilingPattern : public RewritePattern {
- // Entry point to match any LinalgOp OpInterface.
- LinalgBaseTilingPattern(
+///
+/// Linalg tiling pattern.
+///
+/// Apply the `tiling` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tiling` for more details.
+// TODO: TiledOpInterface
+struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
+ /// Construct a generic pattern applied to all LinalgOp that verify `f`.
+ LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- // Entry point to match a specific Linalg op.
- LinalgBaseTilingPattern(
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
- TiledLinalgOp &result) const;
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<TiledLinalgOp>
+ returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -620,68 +636,6 @@ struct LinalgBaseTilingPattern : public RewritePattern {
LinalgTilingOptions options;
};
-template <typename OpTy>
-struct LinalgTilingPattern : public LinalgBaseTilingPattern {
- /// SFINAE: This constructor can only trigger for concrete ops that have a
- /// static `getOperationName` method.
- template <typename ConcreateOpTy = OpTy>
- LinalgTilingPattern(
- MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
- options, filter, benefit) {}
-
- /// This constructor is available to anyone.
- LinalgTilingPattern(
- StringRef opName, MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- TiledLinalgOp tiledLinalgOp;
- if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
- tiledLinalgOp)))
- return failure();
- if (tiledLinalgOp.tensorResults.empty())
- rewriter.eraseOp(op);
- else
- rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
- return success();
- }
-};
-
-struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
- /// Entry point to match any LinalgOp OpInterface.
- /// MatchAnyOpTag-based constructor with a mandatory `filter`.
- LinalgGenericTilingPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
- LinalgTilingOptions options = LinalgTilingOptions(),
- PatternBenefit benefit = 1)
- : LinalgBaseTilingPattern(context, options, filter, benefit) {}
- /// Entry point to match a specific Linalg op.
- LinalgGenericTilingPattern(
- StringRef opName, MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- TiledLinalgOp tiledLinalgOp;
- if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
- tiledLinalgOp)))
- return failure();
- if (tiledLinalgOp.tensorResults.empty())
- rewriter.eraseOp(op);
- else
- rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
- return success();
- }
-};
-
///
/// Linalg padding pattern.
///
@@ -1395,6 +1349,32 @@ struct ExtractSliceOfPadTensorSwapPattern
PatternRewriter &rewriter) const override;
};
+//===----------------------------------------------------------------------===//
+// Helper classes for type list expansion.
+//===----------------------------------------------------------------------===//
+template <typename... OpTypes>
+class TilingPatterns;
+
+template <>
+class TilingPatterns<> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const LinalgTilingOptions &options,
+ const LinalgTransformationFilter &f) {}
+};
+
+template <typename OpTy, typename... OpTypes>
+class TilingPatterns<OpTy, OpTypes...> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const LinalgTilingOptions &options,
+ const LinalgTransformationFilter &f) {
+ patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
+ patterns.getContext(), options, f);
+ TilingPatterns<OpTypes...>::insert(patterns, options, f);
+ }
+};
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f426af01d8722..bc53a719a4741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -784,7 +784,9 @@ tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
tileSizes[i] = zero;
LinalgTilingOptions tileFusedLoopsOptions = options;
tileFusedLoopsOptions.setTileSizes(tileSizes);
- return tileLinalgOp(b, op, tileFusedLoopsOptions);
+ // TODO: Propagate RewriterBase everywhere.
+ IRRewriter rewriter(b);
+ return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
}
/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6bdcc192e27aa..eb1415dabde2d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -283,10 +283,14 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
tileInterchange.begin(), tileInterchange.end()))
.setTileSizes(tileSizes)
.setLoopType(LinalgTilingLoopType::Loops);
- Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
+
+ // TODO: Propagate RewriterBase everywhere.
+ IRRewriter rewriter(b);
+ FailureOr<TiledLinalgOp> tiledRootOp =
+ tileLinalgOp(rewriter, rootOp, tilingOptions);
// Exit if tiling the root operation fails.
- if (!tiledRootOp.hasValue())
+ if (failed(tiledRootOp))
return failure();
// Replace all uses of the root operation if it has been tiled before. All
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 136f38feedf35..859f3f8521b13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -1,4 +1,4 @@
-//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
+//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -93,14 +93,13 @@ struct LinalgStrategyTilePass
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
return;
- RewritePatternSet tilingPattern(funcOp.getContext());
- if (!anchorOpName.empty()) {
- tilingPattern.add<LinalgGenericTilingPattern>(
- anchorOpName, funcOp.getContext(), options, filter);
- } else {
- tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
- options);
- }
+ MLIRContext *ctx = funcOp.getContext();
+ RewritePatternSet tilingPattern(ctx);
+ if (!anchorOpName.empty())
+ tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
+ filter);
+ else
+ tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index cb2987973ea51..89ca83375c0f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -51,7 +51,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
-makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ValueRange allShapeSizes, ValueRange allTileSizes) {
assert(allTileSizes.size() == map.getNumResults());
// Apply `map` to get shape sizes in loop order.
@@ -129,7 +129,7 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
// TODO: Investigate whether mixing implicit and explicit indices
// does not lead to losing information.
static void
-transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
for (auto &en : enumerate(allIvs)) {
@@ -144,7 +144,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
tensor::ExtractSliceOp sliceOp, Value source,
Value dest) {
return b.create<tensor::InsertSliceOp>(
@@ -155,7 +155,7 @@ static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
-tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
+tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
const LinalgTilingOptions &options) {
auto nLoops = op.getNumLoops();
// Initial tile sizes may be too big, only take the first nLoops.
@@ -216,7 +216,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
LinalgOp res = op;
SmallVector<Value, 4> ivs, tensorResults;
auto tiledLoopBodyBuilder =
- [&](OpBuilder &b, Location loc, ValueRange localIvs,
+ [&](OpBuilder &builder, Location loc, ValueRange localIvs,
ValueRange operandValuesToUse) -> scf::ValueVector {
ivs.assign(localIvs.begin(), localIvs.end());
@@ -255,9 +255,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
// TODO: use an interface/adaptor to avoid leaking position in
// `tiledOperands`.
Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+ // TODO: Propagate RewriterBase everywhere.
+ IRRewriter rewriter(b);
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
- tensorResults.push_back(insertSliceIntoTensor(
- b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
+ tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
+ res->getResult(resultIdx),
+ sliceOp.source()));
} else {
tensorResults.push_back(res->getResult(resultIdx));
}
@@ -299,7 +302,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
- OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
+ RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -321,7 +324,7 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
}
FailureOr<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
+mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
const LinalgTilingOptions &options) {
switch (options.loopType) {
case LinalgTilingLoopType::Loops:
@@ -338,7 +341,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled) PadTensorOp
/// and the loop nest.
-static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
+static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
PadTensorOp &newPadOp, LoopNest &loopNest,
const LinalgTilingOptions &options) {
Location loc = op.getLoc();
@@ -384,8 +387,10 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
assert(sliceOp && "expected ExtractSliceOp");
// Insert the tile into the output tensor.
+ // TODO: Propagate RewriterBase everywhere.
+ IRRewriter rewriter(b);
Value yieldValue =
- insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
+ insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
return scf::ValueVector({yieldValue});
});
return success();
@@ -434,31 +439,6 @@ class CanonicalizationPatternList<OpTy, OpTypes...> {
CanonicalizationPatternList<OpTypes...>::insert(patterns);
}
};
-
-/// Helper classes for type list expansion.
-template <typename... OpTypes>
-class RewritePatternList;
-
-template <>
-class RewritePatternList<> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgTilingOptions &options) {}
-};
-
-template <typename OpTy, typename... OpTypes>
-class RewritePatternList<OpTy, OpTypes...> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgTilingOptions &options) {
- auto *ctx = patterns.getContext();
- patterns.add<LinalgTilingPattern<OpTy>>(
- ctx, options,
- LinalgTransformationFilter(ArrayRef<StringAttr>{},
- StringAttr::get(ctx, "tiled")));
- RewritePatternList<OpTypes...>::insert(patterns, options);
- }
-};
} // namespace
RewritePatternSet
@@ -500,11 +480,14 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {
- RewritePatternList<GenericOp,
+ auto *ctx = patterns.getContext();
+ LinalgTransformationFilter f(ArrayRef<StringAttr>{},
+ StringAttr::get(ctx, "tiled"));
+ TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::insert(patterns, options);
- patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+ >::insert(patterns, options, f);
+ patterns.add<PadTensorOpTilingPattern>(ctx, options);
}
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 177a2abda6e7d..c1482f44b4cd0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1,4 +1,4 @@
-//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
+//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -284,19 +284,6 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
return paddedSubviewResults;
}
-/// Linalg base tiling pattern.
-mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
- StringRef opName, MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter, PatternBenefit benefit)
- : RewritePattern(opName, benefit, context), filter(std::move(filter)),
- options(std::move(options)) {}
-
-mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
- MLIRContext *context, LinalgTilingOptions options,
- LinalgTransformationFilter filter, PatternBenefit benefit)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(filter)), options(std::move(options)) {}
-
/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
@@ -325,14 +312,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
}
/// Peel loops after tiling.
-static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
- const LinalgTilingOptions &options) {
- for (int64_t loop : options.peeledLoops) {
+void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
+ ArrayRef<int64_t> peeledLoops,
+ LinalgTilingLoopType loopType) {
+ for (int64_t loop : peeledLoops) {
assert(loop < static_cast<int64_t>(res.loops.size()) &&
"requested peeling of non-existing loop");
SmallVector<Value, 4> loopResults;
Operation *loopOp = res.loops[loop];
- if (options.loopType == LinalgTilingLoopType::TiledLoops) {
+ if (loopType == LinalgTilingLoopType::TiledLoops) {
assert(llvm::all_of(
res.loops,
[&](Operation *op) { return op == res.loops.front(); }) &&
@@ -352,28 +340,6 @@ static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
}
}
-LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
- Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
- return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
-
- Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
-
- if (!res)
- return failure();
- // Clear filter to stop recursive pattern application.
- filter.replaceLinalgTransformationFilter(rewriter, res->op);
-
- // Peel loops.
- peelLoops(rewriter, *res, options);
-
- result = *res;
- return success();
-}
-
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
if (tiledOp.loops.empty())
return tiledOp.op.getOperation()->getResults();
@@ -459,9 +425,9 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
})) {
LinalgTilingOptions unfusedTilingOptions = tilingOptions;
unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
- Optional<TiledLinalgOp> unfusedTiledOp =
+ FailureOr<TiledLinalgOp> unfusedTiledOp =
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
- if (!unfusedTiledOp)
+ if (failed(unfusedTiledOp))
return failure();
rewriter.replaceOp(tiledAndFusedOps->op,
getTiledOpResult(unfusedTiledOp.getValue()));
@@ -485,6 +451,48 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return success();
}
+/// Linalg tiling pattern.
+mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
+ MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {}
+
+mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
+ StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {
+ this->filter.addFilter([opName](Operation *op) {
+ return success(op->getName().getStringRef() == opName);
+ });
+}
+
+FailureOr<TiledLinalgOp>
+mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
+ LinalgOp op, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
+ if (failed(res))
+ return failure();
+
+ // Clear filter to stop recursive pattern application.
+ // This must be done here to properly propagate to peeling branches.
+ filter.replaceLinalgTransformationFilter(rewriter, res->op);
+
+ // Peel the loops of the TiledLinalgOp.
+ peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
+
+ if (res->tensorResults.empty())
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, res->tensorResults);
+
+ return res;
+}
+
/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
MLIRContext *context, LinalgPaddingOptions options,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4a597f64d72ff..14593800b16f1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1178,8 +1178,9 @@ static void populateVectorizationPatterns(
constexpr static StringRef kTiledMarker = "TILED";
constexpr static StringRef kPromotedMarker = "PROMOTED";
- tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
- context, LinalgTilingOptions().setTileSizes(tileSizes),
+ tilingPatterns.add<LinalgTilingPattern>(
+ ConvOp::getOperationName(), context,
+ LinalgTilingOptions().setTileSizes(tileSizes),
LinalgTransformationFilter(ArrayRef<StringAttr>{},
StringAttr::get(kTiledMarker, context)));
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index d4119f26c8197..0c8ab052a88c1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -138,32 +138,36 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg tiling patterns.
//===--------------------------------------------------------------------===//
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
StringAttr::get(ctx, "L3")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
StringAttr::get(ctx, "L2")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
StringAttr::get(ctx, "L1")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
StringAttr::get(ctx, "REG")));
- patterns.add<LinalgTilingPattern<MatvecOp>>(
- ctx,
+ patterns.add<LinalgTilingPattern>(
+ MatvecOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgTransformationFilter(ArrayRef<StringAttr>{},
StringAttr::get(ctx, "L1")));
- patterns.add<LinalgTilingPattern<DotOp>>(
- ctx, LinalgTilingOptions().setTileSizes(8000),
+ patterns.add<LinalgTilingPattern>(
+ DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgTransformationFilter(
ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
StringAttr::get(ctx, "L3"),
@@ -173,32 +177,34 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===--------------------------------------------------------------------===//
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
StringAttr::get(ctx, "L2__with_perm__")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
StringAttr::get(ctx, "L1__with_perm__")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
StringAttr::get(ctx, "REG__with_perm__")));
- patterns.add<LinalgTilingPattern<MatvecOp>>(
- ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
+ patterns.add<LinalgTilingPattern>(
+ MatvecOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
StringAttr::get(ctx, "L1__with_perm__")));
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
@@ -274,8 +280,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
SmallVectorImpl<RewritePatternSet> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
patternsVector.emplace_back(
- ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
+ ctx, std::make_unique<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({8, 12, 16})
.setInterchange({1, 0, 2}),
@@ -339,8 +345,9 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
RewritePatternSet &patterns) {
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
+ LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgTransformationFilter(StringAttr::get(ctx, "START"),
StringAttr::get(ctx, "PROMOTE")));
patterns.add<LinalgPromotionPattern<MatmulOp>>(
@@ -382,8 +389,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
2, DistributionMethod::CyclicNumProcsEqNumIters);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -399,8 +406,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
2, DistributionMethod::CyclicNumProcsGeNumIters);
cyclicNprocsGeNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -416,8 +423,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsDefault.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -433,8 +440,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters,
DistributionMethod::CyclicNumProcsGeNumIters};
cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -450,8 +457,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsGeNumIters,
DistributionMethod::Cyclic};
cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -468,8 +475,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters};
cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -485,8 +492,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.add<LinalgTilingPattern<MatmulOp>>(
- context,
+ patterns.add<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
@@ -507,8 +514,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
- ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
+ ctx, std::make_unique<LinalgTilingPattern>(
+ MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
@@ -589,10 +596,9 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
} else {
linalgTilingOptions.setTileSizes(tileSizes);
}
- tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
- linalg::LinalgTilingPattern<linalg::GenericOp>>(
- context, linalgTilingOptions,
- linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
+ linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
+ TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
+ tilingPattern, linalgTilingOptions, f);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
More information about the Mlir-commits
mailing list