[Mlir-commits] [mlir] d9f645f - [mlir] Allow specifying benefit for C func ptr style patterns.
Jacques Pienaar
llvmlistbot at llvm.org
Thu Dec 22 09:10:22 PST 2022
Author: Chenguang Wang
Date: 2022-12-22T09:10:15-08:00
New Revision: d9f645fe5081fccbe59560989cdf8ea4535946fc
URL: https://github.com/llvm/llvm-project/commit/d9f645fe5081fccbe59560989cdf8ea4535946fc
DIFF: https://github.com/llvm/llvm-project/commit/d9f645fe5081fccbe59560989cdf8ea4535946fc.diff
LOG: [mlir] Allow specifying benefit for C func ptr style patterns.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D139234
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/unittests/IR/PatternMatchTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0743e377150e1..3ee533c89c537 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1638,12 +1638,15 @@ class RewritePatternSet {
// Add a matchAndRewrite style pattern represented as a C function pointer.
template <typename OpType>
- RewritePatternSet &add(LogicalResult (*implFn)(OpType,
- PatternRewriter &rewriter)) {
+ RewritePatternSet &
+ add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+ PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
struct FnPattern final : public OpRewritePattern<OpType> {
FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
- MLIRContext *context)
- : OpRewritePattern<OpType>(context), implFn(implFn) {}
+ MLIRContext *context, PatternBenefit benefit,
+ ArrayRef<StringRef> generatedNames)
+ : OpRewritePattern<OpType>(context, benefit, generatedNames),
+ implFn(implFn) {}
LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
@@ -1653,7 +1656,8 @@ class RewritePatternSet {
private:
LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
};
- add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+ add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
+ generatedNames));
return *this;
}
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 6454f05fbf2fa..3a58d5c1634d9 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -28,3 +28,22 @@ TEST(OpRewritePatternTest, GetGeneratedNames) {
ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
}
} // end anonymous namespace
+
+namespace {
+LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
+ return failure();
+}
+TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
+ MLIRContext context;
+ RewritePatternSet patterns(&context);
+
+ patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
+ /*generatedNames=*/{test::OpB::getOperationName()});
+ ASSERT_EQ(patterns.getNativePatterns().size(), 1);
+ auto &pattern = patterns.getNativePatterns().front();
+ ASSERT_EQ(pattern->getBenefit(), 3);
+ ASSERT_EQ(pattern->getGeneratedOps().size(), 1);
+ ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
+ test::OpB::getOperationName());
+}
+} // end anonymous namespace
More information about the Mlir-commits mailing list