[Mlir-commits] [mlir] e195e6b - [mlir] GreedyPatternRewriteDriver: Enqueue ancestors in MultiOpPatternRewriteDriver

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 01:39:19 PST 2023


Author: Matthias Springer
Date: 2023-01-27T10:39:10+01:00
New Revision: e195e6bad6706230a4b5fd4b5cc13de1f16f25cc

URL: https://github.com/llvm/llvm-project/commit/e195e6bad6706230a4b5fd4b5cc13de1f16f25cc
DIFF: https://github.com/llvm/llvm-project/commit/e195e6bad6706230a4b5fd4b5cc13de1f16f25cc.diff

LOG: [mlir] GreedyPatternRewriteDriver: Enqueue ancestors in MultiOpPatternRewriteDriver

The `GreedyPatternRewriteDriver` was extended to enqueue ancestors in D140304. With this change, `MultiOpPatternRewriteDriver` behaves the same way.

Note: `MultiOpPatternRewriteDriver` now also has a scope that limits how far up the ancestor chain it goes when enqueuing ancestors. By default, the scope is the closest common enclosing region of all given ops.

Differential Revision: https://reviews.llvm.org/D141945

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index f72dbb7ff2986..dce47834547e8 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -112,6 +112,12 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
 /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
 ///   simplified. All other ops are excluded.
 ///
+/// In addition to strictness, a region scope can be specified. Only ops within
+/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
+/// where only ops within the given regions are simplified. If no scope is
+/// specified, it is assumed to be the closest common enclosing region of the
+/// given ops.
+///
 /// Note that ops in `ops` could be erased as result of folding, becoming dead,
 /// or via pattern rewrites. If more far reaching simplification is desired,
 /// applyPatternsAndFoldGreedily should be used.
@@ -123,7 +129,8 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                                      const FrozenRewritePatternSet &patterns,
                                      GreedyRewriteStrictness strictMode,
                                      bool *changed = nullptr,
-                                     bool *allErased = nullptr);
+                                     bool *allErased = nullptr,
+                                     Region *scope = nullptr);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2b3a796dee93f..ead229dacae80 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
@@ -43,10 +44,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions);
 
-  /// Add the given operation to the worklist. Parent ops may or may not be
-  /// added to the worklist, depending on the type of rewrite driver. By
-  /// default, parent ops are added.
-  virtual void addToWorklist(Operation *op);
+  /// Add the given operation and its ancestors to the worklist.
+  void addToWorklist(Operation *op);
 
   /// Pop the next operation from the worklist.
   Operation *popFromWorklist();
@@ -60,7 +59,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 protected:
   /// Add the given operation to the worklist.
-  void addSingleOpToWorklist(Operation *op);
+  virtual void addSingleOpToWorklist(Operation *op);
 
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -103,11 +102,12 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Configuration information for how to simplify.
   GreedyRewriteConfig config;
 
-private:
   /// Only ops within this scope are simplified. This is set at the beginning
-  /// of `simplify()` to the current scope the rewriter operates on.
+  /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
+  /// operates on.
   DenseSet<Region *> scope;
 
+private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -126,6 +126,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 }
 
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+  scope.clear();
   for (Region &r : regions)
     scope.insert(&r);
 
@@ -581,7 +582,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
         strictMode(strictMode) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
-  /// ops. `strictMode` controls which other ops are simplified.
+  /// ops. `strictMode` controls which other ops are simplified. Only ops
+  /// within the given scope region are added to the worklist. If no scope is
+  /// specified, it is assumed to be the closest common region of all `ops`.
   ///
   /// Note that ops in `ops` could be erased as a result of folding, becoming
   /// dead, or via pattern rewrites. The return value indicates convergence.
@@ -589,9 +592,11 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   /// All `ops` that survived the rewrite are stored in `surviving`.
   LogicalResult
   simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
-                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
+                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
+                  Region *scope = nullptr);
 
-  void addToWorklist(Operation *op) override {
+protected:
+  void addSingleOpToWorklist(Operation *op) override {
     if (strictMode == GreedyRewriteStrictness::AnyOp ||
         strictModeFilteredOps.contains(op))
       GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
@@ -632,7 +637,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 
 LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
     ArrayRef<Operation *> ops, bool *changed,
-    llvm::SmallDenseSet<Operation *, 4> *surviving) {
+    llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
   auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
   if (surviving) {
     survivingOps = surviving;
@@ -645,12 +650,16 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
 
+  assert(scope && "scope is mandatory");
+  this->scope.clear();
+  this->scope.insert(scope);
+
   if (changed)
     *changed = false;
   worklist.clear();
   worklistMap.clear();
   for (Operation *op : ops)
-    addToWorklist(op);
+    addSingleOpToWorklist(op);
 
   // These are scratch vectors used in the folding loop below.
   SmallVector<Value, 8> originalOperands, resultValues;
@@ -742,9 +751,37 @@ LogicalResult mlir::applyOpPatternsAndFold(
   return converged;
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
-    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
+/// Find the region that is the closest common ancestor of all given ops.
+static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
+  assert(!ops.empty() && "expected at least one op");
+  // Fast path in case there is only one op.
+  if (ops.size() == 1)
+    return ops.front()->getParentRegion();
+
+  Region *region = ops.front()->getParentRegion();
+  ops = ops.drop_front();
+  int sz = ops.size();
+  llvm::BitVector remainingOps(sz, true);
+  do {
+    int pos = -1;
+    // Iterate over all remaining ops.
+    while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
+      // Is this op contained in `region`?
+      if (region->findAncestorOpInRegion(*ops[pos]))
+        remainingOps.reset(pos);
+    }
+    if (remainingOps.none())
+      break;
+  } while ((region = region->getParentRegion()));
+  assert(region && "could not find common parent region");
+  return region;
+}
+
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteStrictness strictMode, bool *changed,
+                             bool *allErased, Region *scope) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -753,12 +790,25 @@ LogicalResult mlir::applyOpPatternsAndFold(
     return success();
   }
 
+  if (!scope) {
+    // Compute scope if none was provided.
+    scope = findCommonAncestor(ops);
+  } else {
+    // If a scope was provided, make sure that all ops are in scope.
+#ifndef NDEBUG
+    bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
+      return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+    });
+    assert(allOpsInScope && "ops must be within the specified scope");
+#endif // NDEBUG
+  }
+
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                      strictMode);
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  LogicalResult converged =
-      driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+  LogicalResult converged = driver.simplifyLocally(
+      ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
   if (allErased)
     *allErased = surviving.empty();
   return converged;


        


More information about the Mlir-commits mailing list