[Mlir-commits] [mlir] 9cd7e88 - [mlir][Linalg] NFC - Modernize more transformation patterns.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jan 6 14:41:26 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-06T17:40:23-05:00
New Revision: 9cd7e880fd22ff3f8058e81dcd59c2f03074053d
URL: https://github.com/llvm/llvm-project/commit/9cd7e880fd22ff3f8058e81dcd59c2f03074053d
DIFF: https://github.com/llvm/llvm-project/commit/9cd7e880fd22ff3f8058e81dcd59c2f03074053d.diff
LOG: [mlir][Linalg] NFC - Modernize more transformation patterns.
Differential Revision: https://reviews.llvm.org/D116763
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4b55caed849d..72726f7b006b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -442,11 +442,19 @@ struct LinalgTransformationFilter {
filters.push_back(f);
return *this;
}
+
template <typename... OpTypes>
LinalgTransformationFilter &addOpFilter() {
return addFilter(
[](Operation *op) { return success(isa<OpTypes...>(op)); });
}
+
+ /// Add a filter that only accepts operations whose registered name is
+ /// exactly `opName`.
+ LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+ // Capture an owning copy of the name: the stored filter outlives this call,
+ // and `opName` may point into a temporary string owned by the caller.
+ return addFilter([name = opName.str()](Operation *op) {
+ return success(op->getName().getStringRef() == StringRef(name));
+ });
+ }
+
LinalgTransformationFilter &setMatchByDefault() {
matchByDefault = true;
return *this;
@@ -607,7 +615,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// See `tiling` for more details.
// TODO: TiledOpInterface
struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
- /// Construct a generic pattern applied to all LinalgOp that verify `f`.
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter f = LinalgTransformationFilter(),
@@ -643,20 +651,29 @@ struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `padding` for more details.
struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
- // Entry point to match any LinalgOp OpInterface.
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgPaddingPattern(
MLIRContext *context,
LinalgPaddingOptions options = LinalgPaddingOptions(),
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- // Entry point to match a specific LinalgOp.
+
+ /// Construct a pattern specifically applied to `opName`.
LinalgPaddingPattern(
StringRef opName, MLIRContext *context,
LinalgPaddingOptions options = LinalgPaddingOptions(),
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(LinalgOp,
- PatternRewriter &rewriter) const override;
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<LinalgOp> returningMatchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -679,7 +696,7 @@ struct LinalgBaseTileAndFusePattern : public RewritePattern {
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
LinalgTransformationFilter originalOpMarker =
LinalgTransformationFilter(),
@@ -711,14 +728,14 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
LinalgTileAndFusePattern(
MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
LinalgTransformationFilter originalOpMarker =
LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTileAndFusePattern(
OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
- fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
+ fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {}
};
///
@@ -731,13 +748,13 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
// Entry point to match any LinalgOp.
LinalgTileAndFuseTensorOpsPattern(
MLIRContext *context, LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific LinalgOp.
LinalgTileAndFuseTensorOpsPattern(
StringRef opName, MLIRContext *context,
LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
@@ -757,12 +774,22 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ /// GenericOp-specific constructor with an optional `filter`.
GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override;
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<GenericOp>
+ returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(GenericOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -777,19 +804,29 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
-struct LinalgGeneralizationPattern : public RewritePattern {
- // Entry point to match any LinalgOp OpInterface.
+struct LinalgGeneralizationPattern
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgGeneralizationPattern(
MLIRContext *context,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- // Entry point to match a specific Linalg op.
+
+ /// Construct a pattern specifically applied to `opName`.
LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<GenericOp>
+ returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -806,13 +843,13 @@ struct LinalgBasePromotionPattern : public RewritePattern {
/// Entry point to match any LinalgOp OpInterface.
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgBasePromotionPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
+ MLIRContext *context, LinalgTransformationFilter f,
LinalgPromotionOptions options = LinalgPromotionOptions(),
PatternBenefit benefit = 1);
/// Entry point to match a specific Linalg op.
LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
@@ -832,16 +869,16 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
template <typename ConcreateOpTy = OpTy>
LinalgPromotionPattern(
MLIRContext *context, LinalgPromotionOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
- filter, benefit) {}
+ f, benefit) {}
/// This constructor is available to anyone.
LinalgPromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : LinalgBasePromotionPattern(opName, context, options, filter, benefit) {}
+ : LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
};
///
@@ -852,39 +889,28 @@ struct LinalgVectorizationOptions {};
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
-struct LinalgBaseVectorizationPattern : public RewritePattern {
- /// MatchAnyOpTag-based constructor with a mandatory `filter`.
- LinalgBaseVectorizationPattern(MLIRContext *context,
- LinalgTransformationFilter filter,
- PatternBenefit benefit = 1);
- /// Name-based constructor with an optional `filter`.
- LinalgBaseVectorizationPattern(
+struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
+ LinalgVectorizationPattern(
+ MLIRContext *context,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+ PatternBenefit benefit = 1);
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgVectorizationPattern(
StringRef opName, MLIRContext *context,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
+
+ LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
-};
-
-struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
- /// These constructors are available to anyone.
- /// MatchAnyOpTag-based constructor with a mandatory `filter`.
- LinalgVectorizationPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
- LinalgVectorizationOptions options = LinalgVectorizationOptions(),
- PatternBenefit benefit = 1)
- : LinalgBaseVectorizationPattern(context, filter, benefit) {}
- /// Name-based constructor with an optional `filter`.
- LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context,
- LinalgVectorizationOptions options = LinalgVectorizationOptions(),
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
+ LinalgVectorizationOptions options;
};
//===----------------------------------------------------------------------===//
@@ -1008,48 +1034,6 @@ struct LinalgVectorLoweringOptions {
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
-/// Trait to check if T provides a `getOperationName` method.
-template <typename T, typename... Args>
-using has_get_operation_name = decltype(T::getOperationName());
-template <typename T>
-using detect_has_get_operation_name =
- llvm::is_detected<has_get_operation_name, T>;
-
-/// SFINAE helper for single C++ op with a `getOperationName` method.
-template <
- typename OpType,
- typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
- typename = void>
-void insertVectorizationPatternImpl(RewritePatternSet &patternList,
- linalg::LinalgVectorizationOptions options,
- linalg::LinalgTransformationFilter f) {
- patternList.add<linalg::LinalgVectorizationPattern>(
- OpType::getOperationName(), patternList.getContext(), options, f);
-}
-
-/// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
-/// an OpInterface).
-template <typename OpType, typename = std::enable_if_t<
- !detect_has_get_operation_name<OpType>::value>>
-void insertVectorizationPatternImpl(RewritePatternSet &patternList,
- linalg::LinalgVectorizationOptions options,
- linalg::LinalgTransformationFilter f) {
- patternList.add<linalg::LinalgVectorizationPattern>(
- patternList.getContext(), f.addOpFilter<OpType>(), options);
-}
-
-/// Variadic helper function to insert vectorization patterns for C++ ops.
-template <typename... OpTypes>
-void insertVectorizationPatterns(RewritePatternSet &patternList,
- linalg::LinalgVectorizationOptions options,
- linalg::LinalgTransformationFilter f =
- linalg::LinalgTransformationFilter()) {
- // FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0,
- (insertVectorizationPatternImpl<OpTypes>(patternList, options, f), 0)...};
-}
-
///
/// Linalg lowering patterns.
///
@@ -1067,10 +1051,10 @@ template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(
MLIRContext *context, LinalgLoweringType loweringType,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), benefit, context),
- filter(filter), loweringType(loweringType) {}
+ filter(std::move(f)), loweringType(loweringType) {}
// TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op,
@@ -1352,6 +1336,29 @@ struct ExtractSliceOfPadTensorSwapPattern
//===----------------------------------------------------------------------===//
// Helper classes for type list expansion.
//===----------------------------------------------------------------------===//
+/// Helper that registers one name-based `LinalgVectorizationPattern` per op
+/// type in `OpTypes`, all sharing the same `options` and filter `f`.
+template <typename... OpTypes>
+class VectorizationPatterns;
+
+/// Base case of the recursion: an empty op-type list registers nothing.
+template <>
+class VectorizationPatterns<> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const LinalgVectorizationOptions &options,
+ const LinalgTransformationFilter &f) {}
+};
+
+/// Recursive case: adds a pattern anchored on `OpTy::getOperationName()`,
+/// then recurses on the remaining op types.
+template <typename OpTy, typename... OpTypes>
+class VectorizationPatterns<OpTy, OpTypes...> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const LinalgVectorizationOptions &options,
+ const LinalgTransformationFilter &f) {
+ patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
+ patterns.getContext(), options, f);
+ VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
+ }
+};
+
template <typename... OpTypes>
class TilingPatterns;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c1482f44b4cd..22f95653701c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -357,11 +357,11 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
+ LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
: RewritePattern(opName, benefit, context, {}),
dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
- fusionOptions(std::move(fusionOptions)), filter(std::move(filter)),
+ fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
fusedOpMarker(std::move(fusedOpMarker)),
originalOpMarker(std::move(originalOpMarker)) {}
@@ -462,11 +462,7 @@ mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- filter(std::move(f)), options(std::move(options)) {
- this->filter.addFilter([opName](Operation *op) {
- return success(op->getName().getStringRef() == opName);
- });
-}
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
@@ -496,21 +492,18 @@ mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
MLIRContext *context, LinalgPaddingOptions options,
- LinalgTransformationFilter filter, PatternBenefit benefit)
+ LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- filter(std::move(filter)), options(std::move(options)) {}
+ filter(std::move(f)), options(std::move(options)) {}
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
- LinalgTransformationFilter filter, PatternBenefit benefit)
+ LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- filter(std::move(filter)), options(std::move(options)) {
- this->filter.addFilter([opName](Operation *op) {
- return success(op->getName().getStringRef() == opName);
- });
-}
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
-LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
+FailureOr<LinalgOp>
+mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
LinalgOp linalgOp, PatternRewriter &rewriter) const {
if (!linalgOp.hasTensorSemantics())
return failure();
@@ -549,24 +542,24 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
// Replace the original operation to pad.
rewriter.replaceOp(linalgOp, newResults.getValue());
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
- return success();
+ return paddedOp;
}
/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter filter,
+ LinalgTransformationFilter f,
PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(filter)), options(std::move(options)) {}
+ filter(std::move(f)), options(std::move(options)) {}
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter filter,
+ LinalgTransformationFilter f,
PatternBenefit benefit)
- : RewritePattern(opName, benefit, context), filter(std::move(filter)),
+ : RewritePattern(opName, benefit, context), filter(std::move(f)),
options(std::move(options)) {}
LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
@@ -624,11 +617,12 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
- LinalgTransformationFilter filter, PatternBenefit benefit)
- : OpRewritePattern(context, benefit), filter(std::move(filter)),
+ LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), filter(std::move(f)),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
-LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
+FailureOr<GenericOp>
+mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
GenericOp genericOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
@@ -645,41 +639,38 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
- PatternBenefit benefit)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(filter)) {}
+ MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(std::move(f)) {}
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
- StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+ StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(f.addOpNameFilter(opName)) {}
-LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- // TODO: Interface pattern.
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
- return failure();
- if (failed(filter.checkAndNotify(rewriter, op)))
+FailureOr<GenericOp>
+mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
+ LinalgOp linalgOp, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
if (failed(genericOp))
return failure();
filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
- return success();
+ return genericOp;
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
+ MLIRContext *context, LinalgTransformationFilter f,
LinalgPromotionOptions options, PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(filter)), options(std::move(options)) {}
+ filter(std::move(f)), options(std::move(options)) {}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
- LinalgTransformationFilter filter, PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
+ LinalgTransformationFilter f, PatternBenefit benefit)
+ : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
options(std::move(options)) {}
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
@@ -704,24 +695,21 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
return success();
}
-mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
- MLIRContext *context, LinalgTransformationFilter filter,
- PatternBenefit benefit)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(filter)) {}
+/// Generic constructor: matches any LinalgOp accepted by `f`.
+// Store `options` in the member: previously the parameter was dropped and the
+// member was left default-initialized.
+mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
+ MLIRContext *context, LinalgTransformationFilter f,
+ LinalgVectorizationOptions options, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {}
-mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
- StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
- PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {}
+/// Name-based constructor: additionally restricts `f` to ops named `opName`.
+mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
+ StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
+ LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
-LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- // TODO: Interface-based rewrite.
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
- return failure();
- if (failed(filter.checkAndNotify(rewriter, op)))
+LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
+ LinalgOp linalgOp, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
return vectorize(rewriter, linalgOp);
}
@@ -947,10 +935,10 @@ struct DownscaleSizeOneWindowed2DConvolution final
: public OpRewritePattern<Conv2DNhwcHwcfOp> {
DownscaleSizeOneWindowed2DConvolution(
MLIRContext *context,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
- filter(std::move(filter)) {}
+ filter(std::move(f)) {}
LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
@@ -1033,10 +1021,10 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
: public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
DownscaleDepthwiseConv2DNhwcHwcOp(
MLIRContext *context,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
- filter(std::move(filter)) {}
+ filter(std::move(f)) {}
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 0c8ab052a88c..aad40c672c38 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -300,8 +300,7 @@ static void fillL1TilingAndMatmulToVectorPatterns(
MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
patternsVector.back().add<LinalgVectorizationPattern>(
- ctx, LinalgTransformationFilter().addFilter(
- [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
+ ctx, LinalgTransformationFilter().addOpFilter<FillOp, CopyOp>());
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list