[Mlir-commits] [mlir] 2257e4a - [mlir] Allow derived rewrite patterns to define a non-virtual `initialize` hook

River Riddle llvmlistbot at llvm.org
Tue May 18 14:40:52 PDT 2021


Author: River Riddle
Date: 2021-05-18T14:40:32-07:00
New Revision: 2257e4a70e4aabe7255161f3a54922d7dcf1c059

URL: https://github.com/llvm/llvm-project/commit/2257e4a70e4aabe7255161f3a54922d7dcf1c059
DIFF: https://github.com/llvm/llvm-project/commit/2257e4a70e4aabe7255161f3a54922d7dcf1c059.diff

LOG: [mlir] Allow derived rewrite patterns to define a non-virtual `initialize` hook

This is a hook that allows for providing custom initialization of the pattern (e.g. marking it as having bounded recursion, setting the debug name, etc.) without needing to define a custom constructor. A non-virtual hook was chosen to avoid polluting the vtable with code that we really just want to be inlined when constructing the pattern. The alternative would be to define a constructor for each pattern; unfortunately, this creates a lot of otherwise unnecessary boilerplate for many patterns, and a hook provides a much simpler and cleaner interface for the very common case.

Differential Revision: https://reviews.llvm.org/D102440

Added: 
    

Modified: 
    mlir/docs/PatternRewriter.md
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index c4727e620b06..2e4d57e8f6e7 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -125,6 +125,37 @@ can signal this by calling `setHasBoundedRewriteRecursion` when initializing the
 pattern. This will signal to the pattern driver that recursive application of
 this pattern may happen, and the pattern is equipped to safely handle it.
 
+### Initialization
+
+Several pieces of pattern state require explicit initialization by the pattern,
+for example calling `setHasBoundedRewriteRecursion` if a pattern safely handles
+recursive application. This pattern state can be initialized either in the
+constructor of the pattern or via the utility `initialize` hook. Using the
+`initialize` hook removes the need to redefine pattern constructors just to
+inject additional pattern state initialization. An example is shown below:
+
+```c++
+class MyPattern : public RewritePattern {
+public:
+  /// Inherit the constructors from RewritePattern.
+  using RewritePattern::RewritePattern;
+
+  /// Called once by `RewritePattern::create` right after construction.
+  void initialize() {
+    // Signal that this pattern safely handles recursive application.
+    setHasBoundedRewriteRecursion();
+  }
+
+  // ...
+};
+```
+
+### Construction
+
+Construct a `RewritePattern` with the static `RewritePattern::create<T>` utility
+method. This ensures that the pattern is properly initialized (including its
+`initialize` hook) and prepared for insertion into a `RewritePatternSet`.
+
 ## Pattern Rewriter
 
 A `PatternRewriter` is a special class that allows for a pattern to communicate

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 944b6e8a1d1a..621cd85ca360 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -255,10 +255,43 @@ class RewritePattern : public Pattern {
     return failure();
   }
 
+  /// Build a pattern of type `T` from `args`, run its optional non-virtual
+  /// `initialize` hook, and name it after `T` if no debug name was set.
+  template <typename T, typename... Args>
+  static std::unique_ptr<T> create(Args &&... args) {
+    auto newPattern = std::make_unique<T>(std::forward<Args>(args)...);
+    initializePattern<T>(*newPattern);
+
+    // A name chosen by the constructor or by `initialize` takes precedence
+    // over the type-derived fallback.
+    if (newPattern->getDebugName().empty())
+      newPattern->setDebugName(llvm::getTypeName<T>());
+    return newPattern;
+  }
+
 protected:
   /// Inherit the base constructors from `Pattern`.
   using Pattern::Pattern;
 
+private:
+  /// Trait to check if T provides an `initialize` method callable on an lvalue.
+  template <typename T, typename... Args>
+  using has_initialize = decltype(std::declval<T &>().initialize());
+  template <typename T>
+  using detect_has_initialize = llvm::is_detected<has_initialize, T>;
+
+  /// Initialize the derived pattern by calling its `initialize` method.
+  template <typename T>
+  static std::enable_if_t<detect_has_initialize<T>::value>
+  initializePattern(T &pattern) {
+    pattern.initialize();
+  }
+  /// No-op initializer for patterns without an accessible (e.g. public)
+  /// `initialize` method; such patterns are left exactly as constructed.
+  template <typename T>
+  static std::enable_if_t<!detect_has_initialize<T>::value>
+  initializePattern(T &) {}
+
   /// An anchor for the virtual table.
   virtual void anchor();
 };
@@ -992,13 +1025,8 @@ class RewritePatternSet {
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
   addImpl(Args &&... args) {
-    auto pattern = std::make_unique<T>(std::forward<Args>(args)...);
-
-    // Pattern can potentially set name in ctor. Preserve old name if present.
-    if (pattern->getDebugName().empty())
-      pattern->setDebugName(llvm::getTypeName<T>());
-
-    nativePatterns.emplace_back(std::move(pattern));
+    nativePatterns.emplace_back(
+        RewritePattern::create<T>(std::forward<Args>(args)...));
   }
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9db34d7411e9..566559158385 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -935,8 +935,9 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
 class VectorInsertStridedSliceOpSameRankRewritePattern
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
-  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
     // bounded as the rank is strictly decreasing.
     setHasBoundedRewriteRecursion();
@@ -1330,8 +1331,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 class VectorExtractStridedSliceOpConversion
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
-      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
     // is bounded as the rank is strictly decreasing.
     setHasBoundedRewriteRecursion();

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b319257dbe06..cb5546874d68 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -473,8 +473,9 @@ struct TestNonRootReplacement : public RewritePattern {
 /// bounded recursion.
 struct TestBoundedRecursiveRewrite
     : public OpRewritePattern<TestRecursiveRewriteOp> {
-  TestBoundedRecursiveRewrite(MLIRContext *ctx)
-      : OpRewritePattern<TestRecursiveRewriteOp>(ctx) {
+  using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
+
+  void initialize() {
     // The conversion target handles bounding the recursion of this pattern.
     setHasBoundedRewriteRecursion();
   }


        


More information about the Mlir-commits mailing list