[Mlir-commits] [mlir] 724a0e2 - [mlir] GreedyPatternRewriteDriver: Ignore scope when rewriting top-level ops

Matthias Springer llvmlistbot at llvm.org
Fri Feb 3 00:57:05 PST 2023


Author: Matthias Springer
Date: 2023-02-03T09:56:55+01:00
New Revision: 724a0e2c2d7a5724dd81b00db470ba4bb8b616ca

URL: https://github.com/llvm/llvm-project/commit/724a0e2c2d7a5724dd81b00db470ba4bb8b616ca
DIFF: https://github.com/llvm/llvm-project/commit/724a0e2c2d7a5724dd81b00db470ba4bb8b616ca.diff

LOG: [mlir] GreedyPatternRewriteDriver: Ignore scope when rewriting top-level ops

Since D141945, top-level ModuleOps can no longer be transformed with the GreedyPatternRewriteDriver, because they have no enclosing region that could serve as a scope. Make the scope optional inside GreedyPatternRewriteDriver, so that top-level ops can be processed when they are in the initial list of ops.

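Note: This does not allow users to bypass the scoping mechanism by setting `config.scope = nullptr`. When no scope is provided, `applyOpPatternsAndFold` still computes one from the closest common ancestor region of the initial ops. The scope remains `nullptr` only if one of those ops is a top-level op.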

Fixes #60462.

Differential Revision: https://reviews.llvm.org/D143151

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index d8c17c67357b0..423221dd80da0 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -61,7 +61,8 @@ class GreedyRewriteConfig {
   static constexpr int64_t kNoLimit = -1;
 
   /// Only ops within the scope are added to the worklist. If no scope is
-  /// specified, the closest enclosing region is used as a scope.
+  /// specified, the closest enclosing region around the initial list of ops
+  /// is used as a scope.
   Region *scope = nullptr;
 
   /// Strict mode can restrict the ops that are added to the worklist during

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4c5868aead3f1..997bdc6a1c49f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -124,7 +124,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
     : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
-  assert(config.scope && "scope is not specified");
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -266,19 +265,19 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
-  ancestors.push_back(op);
-  while (Region *region = op->getParentRegion()) {
-      if (config.scope == region) {
-        // All gathered ops are in fact ancestors.
-        for (Operation *op : ancestors)
-          addSingleOpToWorklist(op);
-        break;
-      }
-    op = region->getParentOp();
-    if (!op)
-      break;
+  Region *region = nullptr;
+  do {
     ancestors.push_back(op);
-  }
+    region = op->getParentRegion();
+    if (config.scope == region) {
+      // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
+      for (Operation *op : ancestors)
+        addSingleOpToWorklist(op);
+      return;
+    }
+    if (region == nullptr)
+      return;
+  } while ((op = region->getParentOp()));
 }
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
@@ -556,6 +555,9 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
 }
 
 /// Find the region that is the closest common ancestor of all given ops.
+///
+/// Note: This function returns `nullptr` if there is a top-level op among the
+/// given list of ops.
 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   assert(!ops.empty() && "expected at least one op");
   // Fast path in case there is only one op.
@@ -566,7 +568,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   ops = ops.drop_front();
   int sz = ops.size();
   llvm::BitVector remainingOps(sz, true);
-  do {
+  while (region) {
     int pos = -1;
     // Iterate over all remaining ops.
     while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
@@ -576,8 +578,8 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
     }
     if (remainingOps.none())
       break;
-  } while ((region = region->getParentRegion()));
-  assert(region && "could not find common parent region");
+    region = region->getParentRegion();
+  }
   return region;
 }
 
@@ -594,7 +596,8 @@ LogicalResult mlir::applyOpPatternsAndFold(
 
   // Determine scope of rewrite.
   if (!config.scope) {
-    // Compute scope if none was provided.
+    // Compute scope if none was provided. The scope will remain `nullptr` if
+    // there is a top-level op among `ops`.
     config.scope = findCommonAncestor(ops);
   } else {
     // If a scope was provided, make sure that all ops are in scope.


        


More information about the Mlir-commits mailing list