[Mlir-commits] [mlir] d9f645f - [mlir] Allow specifying benefit for C func ptr style patterns.

Jacques Pienaar llvmlistbot at llvm.org
Thu Dec 22 09:10:22 PST 2022


Author: Chenguang Wang
Date: 2022-12-22T09:10:15-08:00
New Revision: d9f645fe5081fccbe59560989cdf8ea4535946fc

URL: https://github.com/llvm/llvm-project/commit/d9f645fe5081fccbe59560989cdf8ea4535946fc
DIFF: https://github.com/llvm/llvm-project/commit/d9f645fe5081fccbe59560989cdf8ea4535946fc.diff

LOG: [mlir] Allow specifying benefit for C func ptr style patterns.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D139234

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/unittests/IR/PatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0743e377150e1..3ee533c89c537 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1638,12 +1638,17 @@ class RewritePatternSet {
 
   // Add a matchAndRewrite style pattern represented as a C function pointer.
+  // `benefit` and `generatedNames` are passed to the OpRewritePattern
+  // constructor of the created pattern.
   template <typename OpType>
-  RewritePatternSet &add(LogicalResult (*implFn)(OpType,
-                                                 PatternRewriter &rewriter)) {
+  RewritePatternSet &
+  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+      PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
-                MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+                MLIRContext *context, PatternBenefit benefit,
+                ArrayRef<StringRef> generatedNames)
+          : OpRewritePattern<OpType>(context, benefit, generatedNames),
+            implFn(implFn) {}
 
       LogicalResult matchAndRewrite(OpType op,
                                     PatternRewriter &rewriter) const override {
@@ -1653,7 +1658,8 @@ class RewritePatternSet {
     private:
       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
     };
-    add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+    add(std::make_unique<FnPattern>(implFn, getContext(), benefit,
+                                    generatedNames));
     return *this;
   }
 

diff  --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 6454f05fbf2fa..3a58d5c1634d9 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -28,3 +28,25 @@ TEST(OpRewritePatternTest, GetGeneratedNames) {
   ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
 }
 } // end anonymous namespace
+
+namespace {
+// Never matches; the test below only inspects how the pattern is registered.
+LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
+  return failure();
+}
+// Checks that the benefit and generated op names given alongside a C function
+// pointer pattern are applied to the pattern added to the set.
+TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
+  MLIRContext context;
+  RewritePatternSet patterns(&context);
+
+  patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
+               /*generatedNames=*/{test::OpB::getOperationName()});
+  ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
+  auto &pattern = patterns.getNativePatterns().front();
+  ASSERT_EQ(pattern->getBenefit(), 3);
+  ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
+  ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
+            test::OpB::getOperationName());
+}
+} // end anonymous namespace


        


More information about the Mlir-commits mailing list