[Mlir-commits] [mlir] 2c4a56c - [mlir][Linalg] NFC - Modernize padding pattern
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jan 6 05:59:39 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-06T08:59:35-05:00
New Revision: 2c4a56c4183f4f01c0b0959acec6972fddd79b7d
URL: https://github.com/llvm/llvm-project/commit/2c4a56c4183f4f01c0b0959acec6972fddd79b7d
DIFF: https://github.com/llvm/llvm-project/commit/2c4a56c4183f4f01c0b0959acec6972fddd79b7d.diff
LOG: [mlir][Linalg] NFC - Modernize padding pattern
Differential Revision: https://reviews.llvm.org/D116739
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c1185c0a8ff70..7592094410632 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -688,7 +688,7 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
/// Apply the `padding` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `padding` for more details.
-struct LinalgPaddingPattern : public RewritePattern {
+struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
// Entry point to match any LinalgOp OpInterface.
LinalgPaddingPattern(
MLIRContext *context,
@@ -701,7 +701,7 @@ struct LinalgPaddingPattern : public RewritePattern {
LinalgPaddingOptions options = LinalgPaddingOptions(),
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
+ LogicalResult matchAndRewrite(LinalgOp,
PatternRewriter &rewriter) const override;
private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8b9c7bd2f60d2..177a2abda6e7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
MLIRContext *context, LinalgPaddingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(std::move(filter)), options(std::move(options)) {}
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
- options(std::move(options)) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ filter(std::move(filter)), options(std::move(options)) {
+ this->filter.addFilter([opName](Operation *op) {
+ return success(op->getName().getStringRef() == opName);
+ });
+}
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
- return failure();
+ LinalgOp linalgOp, PatternRewriter &rewriter) const {
if (!linalgOp.hasTensorSemantics())
return failure();
- if (failed(filter.checkAndNotify(rewriter, op)))
+ if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
// Pad the operation.
@@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
}
// Replace the original operation to pad.
- rewriter.replaceOp(op, newResults.getValue());
+ rewriter.replaceOp(linalgOp, newResults.getValue());
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
return success();
}
More information about the Mlir-commits mailing list