[Mlir-commits] [mlir] 6af2c4c - [mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr

River Riddle llvmlistbot at llvm.org
Mon Dec 14 12:40:00 PST 2020


Author: River Riddle
Date: 2020-12-14T12:32:44-08:00
New Revision: 6af2c4ca9bdb37e56cfda8dae4f6c3c6ca21b8d7

URL: https://github.com/llvm/llvm-project/commit/6af2c4ca9bdb37e56cfda8dae4f6c3c6ca21b8d7
DIFF: https://github.com/llvm/llvm-project/commit/6af2c4ca9bdb37e56cfda8dae4f6c3c6ca21b8d7.diff

LOG: [mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr

This allows a pattern list to be cached and shared across multiple pass instances, for example when a pass is copied for multi-threaded execution. This is especially important for PDL patterns, which are compiled at runtime when the FrozenRewritePatternList is built: sharing the compiled list avoids recompiling the same patterns for every copy.

Differential Revision: https://reviews.llvm.org/D93146

Added: 
    

Modified: 
    mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
    mlir/lib/Rewrite/FrozenRewritePatternList.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
index c2335b9dd5a1..0e583aab3dc4 100644
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
@@ -18,34 +18,52 @@ class PDLByteCode;
 
 /// This class represents a frozen set of patterns that can be processed by a
 /// pattern applicator. This class is designed to enable caching pattern lists
-/// such that they need not be continuously recomputed.
+/// such that they need not be continuously recomputed. Note that all copies of
+/// this class share the same compiled pattern list, allowing for a reduction in
+/// the number of duplicated patterns that need to be created.
 class FrozenRewritePatternList {
   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
   /// Freeze the patterns held in `patterns`, and take ownership.
+  FrozenRewritePatternList();
   FrozenRewritePatternList(OwningRewritePatternList &&patterns);
-  FrozenRewritePatternList(FrozenRewritePatternList &&patterns);
+  FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default;
+  FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default;
+  FrozenRewritePatternList &
+  operator=(const FrozenRewritePatternList &patterns) = default;
+  FrozenRewritePatternList &
+  operator=(FrozenRewritePatternList &&patterns) = default;
   ~FrozenRewritePatternList();
 
   /// Return the native patterns held by this list.
   iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
   getNativePatterns() const {
+    const NativePatternListT &nativePatterns = impl->nativePatterns;
     return llvm::make_pointee_range(nativePatterns);
   }
 
   /// Return the compiled PDL bytecode held by this list. Returns null if
   /// there are no PDL patterns within the list.
   const detail::PDLByteCode *getPDLByteCode() const {
-    return pdlByteCode.get();
+    return impl->pdlByteCode.get();
   }
 
 private:
-  /// The set of.
-  std::vector<std::unique_ptr<RewritePattern>> nativePatterns;
+  /// The internal implementation of the frozen pattern list.
+  struct Impl {
+    /// The set of native C++ rewrite patterns.
+    NativePatternListT nativePatterns;
 
-  /// The bytecode containing the compiled PDL patterns.
-  std::unique_ptr<detail::PDLByteCode> pdlByteCode;
+    /// The bytecode containing the compiled PDL patterns.
+    std::unique_ptr<detail::PDLByteCode> pdlByteCode;
+  };
+
+  /// A pointer to the internal pattern list. This uses a shared_ptr to avoid
+  /// the need to compile the same pattern list multiple times. For example,
+  /// during multi-threaded pass execution, all copies of a pass can share the
+  /// same pattern list.
+  std::shared_ptr<Impl> impl;
 };
 
 } // end namespace mlir

diff  --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
index 60f6dcea88f2..40d7fcde8f33 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
@@ -50,12 +50,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
 // FrozenRewritePatternList
 //===----------------------------------------------------------------------===//
 
+FrozenRewritePatternList::FrozenRewritePatternList()
+    : impl(std::make_shared<Impl>()) {}
+
 FrozenRewritePatternList::FrozenRewritePatternList(
     OwningRewritePatternList &&patterns)
-    : nativePatterns(std::move(patterns.getNativePatterns())) {
-  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
+    : impl(std::make_shared<Impl>()) {
+  impl->nativePatterns = std::move(patterns.getNativePatterns());
 
   // Generate the bytecode for the PDL patterns if any were provided.
+  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
   ModuleOp pdlModule = pdlPatterns.getModule();
   if (!pdlModule)
     return;
@@ -64,14 +68,9 @@ FrozenRewritePatternList::FrozenRewritePatternList(
         "failed to lower PDL pattern module to the PDL Interpreter");
 
   // Generate the pdl bytecode.
-  pdlByteCode = std::make_unique<detail::PDLByteCode>(
+  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
       pdlModule, pdlPatterns.takeConstraintFunctions(),
       pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
 }
 
-FrozenRewritePatternList::FrozenRewritePatternList(
-    FrozenRewritePatternList &&patterns)
-    : nativePatterns(std::move(patterns.nativePatterns)),
-      pdlByteCode(std::move(patterns.pdlByteCode)) {}
-
 FrozenRewritePatternList::~FrozenRewritePatternList() {}


        


More information about the Mlir-commits mailing list