[Mlir-commits] [mlir] 67760d7 - [mlir] GreedyPatternRewriteDriver: Make classes single-use

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 02:02:53 PST 2023


Author: Matthias Springer
Date: 2023-01-27T10:55:16+01:00
New Revision: 67760d7e315ff90198bccfd4b0a3934f7a30e6ce

URL: https://github.com/llvm/llvm-project/commit/67760d7e315ff90198bccfd4b0a3934f7a30e6ce
DIFF: https://github.com/llvm/llvm-project/commit/67760d7e315ff90198bccfd4b0a3934f7a30e6ce.diff

LOG: [mlir] GreedyPatternRewriteDriver: Make classes single-use

Less mutable state, more `const`. This addresses a review concern in D140304 about the complexity of the driver's mutable state.

Differential Revision: https://reviews.llvm.org/D141949

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index aba7a7fd08b9..a5ddd9138873 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,10 +39,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config);
+                                      const GreedyRewriteConfig &config,
+                                      const DenseSet<Region *> &scope);
 
   /// Simplify the operations within the given regions.
-  bool simplify(MutableArrayRef<Region> regions);
+  bool simplify(MutableArrayRef<Region> regions) &&;
 
   /// Add the given operation and its ancestors to the worklist.
   void addToWorklist(Operation *op);
@@ -100,12 +101,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 protected:
   /// Configuration information for how to simplify.
-  GreedyRewriteConfig config;
+  const GreedyRewriteConfig config;
 
-  /// Only ops within this scope are simplified. This is set at the beginning
-  /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
-  /// operates on.
-  DenseSet<Region *> scope;
+  /// Only ops within this scope are simplified.
+  const DenseSet<Region *> scope;
 
 private:
 #ifndef NDEBUG
@@ -117,19 +116,16 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+    const GreedyRewriteConfig &config, const DenseSet<Region *> &scope)
+    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
+      scope(scope) {
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
 }
 
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
-  scope.clear();
-  for (Region &r : regions)
-    scope.insert(&r);
-
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -449,9 +445,15 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
   assert(llvm::all_of(regions, regionIsIsolated) &&
          "patterns can only be applied to operations IsolatedFromAbove");
 
+  // Limit ops on the worklist to this scope.
+  DenseSet<Region *> scope;
+  for (Region &r : regions)
+    scope.insert(&r);
+
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
-  bool converged = driver.simplify(regions);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+                                    scope);
+  bool converged = std::move(driver).simplify(regions);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";
@@ -472,11 +474,12 @@ namespace {
 /// ops are not considered.
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
-  explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
-                                       const FrozenRewritePatternSet &patterns,
-                                       GreedyRewriteStrictness strictMode)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
-        strictMode(strictMode) {}
+  explicit MultiOpPatternRewriteDriver(
+      MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+      const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode,
+      llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
+      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+        strictMode(strictMode), survivingOps(survivingOps) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
   /// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -486,11 +489,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   /// Note that ops in `ops` could be erased as a result of folding, becoming
   /// dead, or via pattern rewrites. The return value indicates convergence.
   ///
-  /// All `ops` that survived the rewrite are stored in `surviving`.
-  LogicalResult
-  simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
-                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
-                  Region *scope = nullptr);
+  /// All `ops` that survived the rewrite are stored in `survivingOps` (if set).
+  LogicalResult simplifyLocally(ArrayRef<Operation *> op,
+                                bool *changed = nullptr) &&;
 
 protected:
   void addSingleOpToWorklist(Operation *op) override {
@@ -516,7 +517,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 
   /// `strictMode` control which ops are added to the worklist during
   /// simplification.
-  GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+  const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
 
   /// The list of ops we are restricting our rewrites to. These include the
   /// supplied set of ops as well as new ops created while rewriting those ops
@@ -527,17 +528,15 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   /// An optional set of ops that survived the rewrite. This set is populated
   /// at the beginning of `simplifyLocally` with the inititally provided list
   /// of ops.
-  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr;
+  llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
 };
 
 } // namespace
 
-LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
-    ArrayRef<Operation *> ops, bool *changed,
-    llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
-  auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
-  if (surviving) {
-    survivingOps = surviving;
+LogicalResult
+MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
+                                             bool *changed) && {
+  if (survivingOps) {
     survivingOps->clear();
     survivingOps->insert(ops.begin(), ops.end());
   }
@@ -547,10 +546,6 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
 
-  assert(scope && "scope is mandatory");
-  this->scope.clear();
-  this->scope.insert(scope);
-
   if (changed)
     *changed = false;
   worklist.clear();
@@ -684,11 +679,13 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
   }
 
   // Start the pattern driver.
-  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     strictMode);
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  LogicalResult converged = driver.simplifyLocally(
-      ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
+  DenseSet<Region *> scopeSet;
+  scopeSet.insert(scope);
+  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
+                                     scopeSet, strictMode,
+                                     allErased ? &surviving : nullptr);
+  LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allErased)
     *allErased = surviving.empty();
   LLVM_DEBUG(if (failed(converged)) {


        


More information about the Mlir-commits mailing list