[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for out-of-scope IR modifications (PR #76219)

Matthias Springer llvmlistbot at llvm.org
Thu Jan 11 05:28:58 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/76219

>From ce33a69e786fab3f29bf6b3f2369f9b8cdc9dc4d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 11 Jan 2024 10:42:12 +0000
Subject: [PATCH] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for
 out-of-scope IR modifications

This commit adds a new "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications.

`GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out of scope are not added to the worklist. The new expensive check triggers a fatal error if:
* An op is inserted into an out-of-scope region.
* An op is removed from an out-of-scope region.
* An op is modified in an out-of-scope region.
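The scope test behind these checks (`isOutOfScope` in the diff below) can be modeled outside of MLIR. The following is a simplified sketch, not MLIR code. The `Op` and `Region` structs, `parentRegion`, and `isModificationAllowed` are hypothetical stand-ins for `Operation`, `Region`, `Region::getParentRegion`, and the check in `notifyOperationModified`. The sketch performs the same walk as the patch. It first looks for a region on the scope's ancestor chain that contains the op, so ops nested in a still-unlinked parent are not flagged. It then reports the op as out of scope if it is not nested in the scope region itself.

```cpp
#include <cassert>

// Hypothetical stand-ins for mlir::Operation and mlir::Region (not MLIR code).
struct Region;
struct Op {
  Region *parent = nullptr;  // Enclosing region; nullptr for top-level or unlinked ops.
};
struct Region {
  Op *owner = nullptr;  // The op that owns this region.
};

// Region enclosing the owner of `r` (cf. Region::getParentRegion).
Region *parentRegion(Region *r) { return r->owner ? r->owner->parent : nullptr; }

// Return the ancestor of `op` (possibly `op` itself) that sits directly in
// region `r`, or nullptr if `op` is not nested in `r`
// (cf. Region::findAncestorOpInRegion).
Op *findAncestorOpInRegion(Region *r, Op *op) {
  assert(r && "expected a region");
  for (; op; op = op->parent ? op->parent->owner : nullptr)
    if (op->parent == r)
      return op;
  return nullptr;
}

// Follows the structure of GreedyPatternRewriteDriver::isOutOfScope.
bool isOutOfScope(Region *scope, Op *op) {
  if (!scope)
    return false;  // No scope set: nothing is out of scope.
  // Walk up from the scope until we find a region that contains `op` (or is
  // owned by `op`), i.e., `op` and the scope belong to the same IR tree.
  Region *r = scope;
  while (r) {
    if (findAncestorOpInRegion(r, op) || r->owner == op)
      break;
    r = parentRegion(r);
  }
  if (!r)
    return false;  // Different tree, e.g., nested in a still-unlinked op.
  // Same tree: out of scope unless nested in the scope region itself.
  return !findAncestorOpInRegion(scope, op);
}

// Modification check: the op that owns the scope region may be modified
// (e.g., its attributes); every other modified op must be in scope.
bool isModificationAllowed(Region *scope, Op *op) {
  return !scope || op == scope->owner || !isOutOfScope(scope, op);
}
```

In this model the op that owns the scope region is itself reported as out of scope by `isOutOfScope`. Only the modification check exempts it, which matches the rule that a pass may modify the attributes of the op it runs on.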

This change also tightens the greedy pattern rewriter entry points and makes sure that the specified `scope` region is owned by an `IsolatedFromAbove` op.
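The default-scope rule for multi-op rewrites and the new entry-point assertions can be sketched in the same kind of toy model. Again, these are hypothetical stand-ins, not MLIR's API. The `isolatedFromAbove` flag replaces the `IsIsolatedFromAbove` trait, and `findCommonIsolatedAncestor` and `scopeForRegionRewrite` are illustrative names. The common-ancestor search follows the behavior described in the header comment. It is not copied from `findCommonAncestor`, because the diff shows that function's loop body only partially.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for mlir::Operation and mlir::Region (not MLIR code).
struct Region;
struct Op {
  Region *parent = nullptr;        // Enclosing region; nullptr if top-level.
  bool isolatedFromAbove = false;  // Models the IsIsolatedFromAbove trait.
};
struct Region {
  Op *owner = nullptr;  // The op that owns this region.
};

Region *parentRegion(Region *r) { return r->owner ? r->owner->parent : nullptr; }

// True if `r` is `other` or one of its ancestors (cf. Region::isAncestor).
bool isAncestor(Region *r, Region *other) {
  for (; other; other = parentRegion(other))
    if (other == r)
      return true;
  return false;
}

// Default scope for a multi-op rewrite: the closest IsolatedFromAbove region
// that contains all ops, or nullptr if one of the ops is top-level.
Region *findCommonIsolatedAncestor(const std::vector<Op *> &ops) {
  assert(!ops.empty() && "expected at least one op");
  Region *region = ops.front()->parent;
  // Widen `region` until it contains every op.
  for (Op *op : ops) {
    while (region && !isAncestor(region, op->parent))
      region = parentRegion(region);
  }
  // Then continue up to the closest IsolatedFromAbove region.
  while (region && !region->owner->isolatedFromAbove)
    region = parentRegion(region);
  return region;
}

// Entry-point checks of the single-region overload: the scope defaults to the
// input region, the input region must be in scope, and the scope must be
// owned by an IsolatedFromAbove op.
Region *scopeForRegionRewrite(Region *region, Region *scope) {
  if (!scope)
    scope = region;
  assert(isAncestor(scope, region) && "input region must be in scope");
  assert(scope->owner->isolatedFromAbove &&
         "greedy pattern rewrite scope must be IsolatedFromAbove");
  return scope;
}
```

With this rule, ops nested in the body of a non-isolated op such as `scf.for` get the enclosing function body as their scope, because only the function op is isolated from above.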

Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because only function bodies are allowed to be modified. (TODO: Should we allow this? Is there something special about functions?)
---
 .../Transforms/GreedyPatternRewriteDriver.h   |  35 ++++--
 .../Utils/GreedyPatternRewriteDriver.cpp      | 103 +++++++++++++++---
 2 files changed, 111 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9c..431608c1f71c81 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -60,10 +60,29 @@ class GreedyRewriteConfig {
 
   static constexpr int64_t kNoLimit = -1;
 
-  /// Only ops within the scope are added to the worklist. If no scope is
-  /// specified, the closest enclosing region around the initial list of ops
-  /// (or the specified region, depending on which greedy rewrite entry point
-  /// is used) is used as a scope.
+  /// Only ops within the scope are allowed to be modified and are added to the
+  /// worklist.
+  ///
+  /// If out-of-scope IR is modified, an assertion will fail inside the greedy
+  /// pattern rewrite driver if expensive checks are enabled (as long as rewrite
+  /// patterns use the rewriter API correctly). We also allow attribute
+  /// modifications of the op that owns the scope region. (This is consistent
+  /// with the fact that passes are allowed to modify attributes of the
+  /// operation that they operate on.)
+  ///
+  /// The scope region must be isolated from above. This ensures that
+  /// out-of-scope ops are not affected by rewrites.
+  ///
+  /// If no scope is specified, it is set as follows:
+  /// * Single op greedy rewrite: a greedy rewrite is performed for every region
+  ///   of the op. (See below.) The scope is set to the respective region of
+  ///   each greedy rewrite.
+  /// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
+  ///   around the initial list of ops. If there is no such region, the scope
+  ///   is `nullptr`. This is because multi-op greedy rewrites are allowed to
+  ///   modify top-level ops. (They are not allowed to erase top-level ops.)
+  /// * Single region greedy rewrite: the specified region. (The op that owns
+  ///   the region must be isolated from above.)
   Region *scope = nullptr;
 
   /// Strict mode can restrict the ops that are added to the worklist during
@@ -124,11 +143,9 @@ applyPatternsAndFoldGreedily(Region &region,
 /// This overload runs a separate greedy rewrite for each region of the
 /// specified op. A region scope can be set in the configuration parameter. By
 /// default, the scope is set to the region of the current greedy rewrite. Only
-/// in-scope ops are added to the worklist and only in-scope ops and the
-/// specified op itself are allowed to be modified by the patterns.
-///
-/// Note: The specified op may be modified, but it may not be removed by the
-/// patterns.
+/// in-scope ops are added to the worklist and only in-scope ops are allowed to
+/// be modified by the patterns. In addition, the attributes of the op that
+/// owns the scope region may also be modified.
 ///
 /// Returns "success" if the iterative process converged (i.e., fixpoint was
 /// reached) and no more patterns can be matched within the region. `changed`
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 67c2d9d59f4c92..c9a49094fac3df 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -324,6 +324,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
 private:
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// Return "true" if the given op is guaranteed to be out of scope.
+  bool isOutOfScope(Operation *op) const;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   /// Look over the provided operands for any defining operations that should
   /// be re-added to the worklist. This function should be called when an
   /// operation is modified or removed, as it may trigger further
@@ -375,6 +380,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
+  // No op is out of scope if no scope was set.
+  if (!config.scope)
+    return false;
+  // Check if the given op and the scope region are part of the same IR tree.
+  // The parent op into which the given op was inserted may be unlinked, in
+  // which case we do not consider the given op to be out of scope. (That parent
+  // op will likely be inserted later, together with all its nested ops.)
+  Region *r = config.scope;
+  while (r) {
+    if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
+      break;
+    r = r->getParentRegion();
+  }
+  if (!r)
+    return false;
+  // Op is out of scope if it is not within the scope region.
+  return !config.scope->findAncestorOpInRegion(*op);
+}
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
 bool GreedyPatternRewriteDriver::processWorklist() {
 #ifndef NDEBUG
   const char *logLineComment =
@@ -579,6 +606,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
         addSingleOpToWorklist(op);
       return;
     }
+    // TODO: Unlinked ops are currently not added to the worklist if a `scope`
+    // is specified.
     if (region == nullptr)
       return;
   } while ((op = region->getParentOp()));
@@ -600,6 +629,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite inserted op into region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -608,10 +644,24 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
 }
 
 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
+  // TODO: This notification should also be triggered when moving an op into
+  // this op.
   LLVM_DEBUG({
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope) {
+    // Modifying attributes of the op that owns the scope region is allowed
+    // when using the applyPatternsAndFoldGreedily(Operation *) entry point.
+    if (op != config.scope->getParentOp() && isOutOfScope(op)) {
+      llvm::report_fatal_error("greedy pattern rewrite modified op within "
+                               "region that is out of scope");
+    }
+  }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationModified(op);
   addToWorklist(op);
@@ -637,16 +687,11 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
                        << ")\n";
   });
 
-#ifndef NDEBUG
-  // Only ops that are within the configured scope are added to the worklist of
-  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
-  // the part of the IR that is taken into account for the "expensive checks".
-  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
-  // region, as that would break the worklist handling and the expensive checks.
-  if (config.scope && config.scope->getParentOp() == op)
-    llvm_unreachable(
-        "scope region must not be erased during greedy pattern rewrite");
-#endif // NDEBUG
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite removed op from region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   if (config.listener)
     config.listener->notifyOperationRemoved(op);
@@ -800,16 +845,22 @@ LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
                                    GreedyRewriteConfig config, bool *changed) {
-  // The top-level operation must be known to be isolated from above to
-  // prevent performing canonicalizations on operations defined at or above
-  // the region containing 'op'.
-  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
-         "patterns can only be applied to operations IsolatedFromAbove");
-
   // Set scope if not specified.
   if (!config.scope)
     config.scope = &region;
 
+  // Make sure that the specified region on which the greedy rewrite should
+  // operate is in scope.
+  assert(config.scope->isAncestor(&region) && "input region must be in scope");
+
+  // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+  // are out of scope are never added to the worklist and any out-of-scope IR
+  // modifications trigger an assertion when expensive checks are
+  // enabled (as long as the rewriter API is used correctly).
+  assert(
+      config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+      "greedy pattern rewrite scope must be IsolatedFromAbove");
+
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (failed(verify(config.scope->getParentOp())))
     llvm::report_fatal_error(
@@ -886,7 +937,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
   return success(worklist.empty());
 }
 
-/// Find the region that is the closest common ancestor of all given ops.
+/// Find the IsolatedFromAbove region that is the closest common ancestor of all
+/// given ops.
 ///
 /// Note: This function returns `nullptr` if there is a top-level op among the
 /// given list of ops.
@@ -896,6 +948,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   if (ops.size() == 1)
     return ops.front()->getParentRegion();
 
+  // Find the closest region that contains all ops.
   Region *region = ops.front()->getParentRegion();
   ops = ops.drop_front();
   int sz = ops.size();
@@ -912,6 +965,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
       break;
     region = region->getParentRegion();
   }
+
+  // Find the closest IsolatedFromAbove region.
+  while (region &&
+         !region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    region = region->getParentRegion();
+
   return region;
 }
 
@@ -932,8 +991,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
     // there is a top-level op among `ops`.
     config.scope = findCommonAncestor(ops);
   } else {
-    // If a scope was provided, make sure that all ops are in scope.
+    // If a scope was provided, make sure that it is IsolatedFromAbove and that
+    // all ops are in scope.
 #ifndef NDEBUG
+    // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+    // are out of scope are never added to the worklist and any out-of-scope IR
+    // modifications trigger an assertion when expensive checks are
+    // enabled (as long as the rewriter API is used correctly).
+    assert(
+        config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+        "greedy pattern rewrite scope must be IsolatedFromAbove");
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
       return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
     });


