[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for out-of-scope IR modifications (PR #76219)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 22 00:51:47 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds an additional "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications.

`GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out of scope are not added to the worklist. The new expensive check triggers a fatal error if:
* An op is inserted into an out-of-scope region.
* An op is removed from an out-of-scope region.
* An op is modified in an out-of-scope region.
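
The check described above can be illustrated with a small, self-contained sketch. It mirrors the logic of `isOutOfScope` from the diff below but does not use the MLIR API: `ToyOp`, `ToyRegion`, `parentRegion` and `isInside` are hypothetical stand-ins that model only the parent links between ops and regions.

```cpp
#include <cassert>

// Toy stand-ins for mlir::Operation and mlir::Region. These types are
// hypothetical and only model the parent links that the check needs.
struct ToyRegion;

struct ToyOp {
  ToyRegion *parent = nullptr; // nullptr: top-level or unlinked op
};

struct ToyRegion {
  ToyOp *owner = nullptr; // op that owns this region
};

// Region that encloses `region`, or nullptr at the top of the tree.
const ToyRegion *parentRegion(const ToyRegion *region) {
  assert(region && "expected a region");
  return region->owner ? region->owner->parent : nullptr;
}

// True if `op` is nested at any depth inside `region`.
bool isInside(const ToyRegion *region, const ToyOp *op) {
  for (const ToyRegion *r = op->parent; r; r = parentRegion(r))
    if (r == region)
      return true;
  return false;
}

// Same shape as the check in the diff: an op is reported as out of scope
// only if it is part of the same IR tree as the scope region but is not
// nested inside the scope region.
bool isOutOfScope(const ToyRegion *scope, const ToyOp *op) {
  // No op is out of scope if no scope was set.
  if (!scope)
    return false;
  // Walk up from the scope until a region is found that contains `op` or
  // is owned by `op`. If there is none, `op` lives in a different
  // (possibly unlinked) tree and is not reported.
  const ToyRegion *r = scope;
  while (r) {
    if (isInside(r, op) || r->owner == op)
      break;
    r = parentRegion(r);
  }
  if (!r)
    return false;
  return !isInside(scope, op);
}
```

With the body region of one function as the scope, this sketch reports the function op itself, the enclosing module op and ops in sibling functions as out of scope, while ops inside the body and unlinked ops are not reported.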

This change also tightens the greedy pattern rewriter entry points and makes sure that the specified `scope` is an `IsolatedFromAbove` region.
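
As a rough illustration of the tightened entry points, the sketch below walks up from a region to the closest enclosing region whose owner is isolated from above, and validates a user-provided scope. It is a toy model rather than MLIR code: `SketchOp`, `SketchRegion`, `closestIsolatedRegion` and `isValidScope` are hypothetical names, and the `isolatedFromAbove` flag stands in for the `IsIsolatedFromAbove` trait.

```cpp
#include <cassert>

// Hypothetical stand-ins for ops and regions; `isolatedFromAbove` models
// the IsIsolatedFromAbove trait of the real op.
struct SketchRegion;

struct SketchOp {
  SketchRegion *parent = nullptr; // nullptr: top-level op
  bool isolatedFromAbove = false;
};

struct SketchRegion {
  SketchOp *owner = nullptr; // op that owns this region
};

// Walk up the region tree until a region is found whose owner is isolated
// from above. Returns nullptr if there is no such region.
SketchRegion *closestIsolatedRegion(SketchRegion *region) {
  while (region) {
    assert(region->owner && "every region in this sketch has an owner");
    if (region->owner->isolatedFromAbove)
      return region;
    region = region->owner->parent;
  }
  return nullptr;
}

// A user-provided scope is accepted only if its owner is isolated from
// above. The diff enforces the equivalent condition with an assertion.
bool isValidScope(const SketchRegion *scope) {
  return scope && scope->owner && scope->owner->isolatedFromAbove;
}
```

In this model, starting from the body of a loop nested in a function yields the function body, because the loop op is not isolated from above while the function op is.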

Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because the `func.func` op is not in scope. (We should not set the scope to be the enclosing module because the greedy pattern rewrite could then modify any function in the module.)

TODO: Should we allow this? Is there something special about functions? (They have no operands or results, so they are quite decoupled from other ops.)

---
Full diff: https://github.com/llvm/llvm-project/pull/76219.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+1-1) 
- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+15-3) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+77-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 4266ac5b0c8cf6..4d21bdf560852d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -90,7 +90,7 @@ def EnableArmStreaming
 }
 
 def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
+    : Pass<"allocate-arm-sme-tiles", "mlir::ModuleOp"> {
   let summary = "Allocate SME tiles";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index b93ffd96bee5fa..fe66ba81f10cc2 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -60,9 +60,21 @@ class GreedyRewriteConfig {
 
   static constexpr int64_t kNoLimit = -1;
 
-  /// Only ops within the scope are added to the worklist. If no scope is
-  /// specified, the closest enclosing region around the initial list of ops
-  /// is used as a scope.
+  /// Only ops within the scope are allowed to be modified and are added to the
+  /// worklist. If out-of-scope IR is modified, an assertion will fail inside
+  /// the greedy pattern rewrite driver if expensive checks are enabled (as long
+  /// as rewrite patterns use the rewriter API correctly).
+  ///
+  /// The op that owns the scope region must be isolated from above. If no scope
+  /// is specified, it is set as follows:
+  /// * Single op greedy rewrite: a greedy rewrite is performed for every region
+  ///   of the op. See below. (The op must be isolated from above.)
+  /// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
+  ///   around the initial list of ops. If there is no such region, the scope
+  ///   is `nullptr`. This is because multi-op greedy rewrites are allowed to
+  ///   modify top-level ops. (They are not allowed to erase top-level ops.)
+  /// * Single region greedy rewrite: the specified region. (The op that owns
+  ///   the region must be isolated from above.)
   Region *scope = nullptr;
 
   /// Strict mode can restrict the ops that are added to the worklist during
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index eca13f52f53dc4..158c312dc60c82 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -327,6 +327,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
 private:
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// Return "true" if the given op is guaranteed to be out of scope.
+  bool isOutOfScope(Operation *op) const;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   /// Look over the provided operands for any defining operations that should
   /// be re-added to the worklist. This function should be called when an
   /// operation is modified or removed, as it may trigger further
@@ -378,6 +383,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
+  // No op is out of scope if no scope was set.
+  if (!config.scope)
+    return false;
+  // Check if the given op and the scope region are part of the same IR tree.
+  // The parent op into which the given op was inserted may be unlinked, in
+  // which case we do not consider the given op to be out of scope. (That parent
+  // op will likely be inserted later, together with all its nested ops.)
+  Region *r = config.scope;
+  while (r) {
+    if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
+      break;
+    r = r->getParentRegion();
+  }
+  if (!r)
+    return false;
+  // Op is out of scope if it is not within the scope region.
+  return !config.scope->findAncestorOpInRegion(*op);
+}
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
 bool GreedyPatternRewriteDriver::processWorklist() {
 #ifndef NDEBUG
   const char *logLineComment =
@@ -518,6 +545,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
         addSingleOpToWorklist(op);
       return;
     }
+    // TODO: Unlinked ops are currently not added to the worklist if a `scope`
+    // is specified.
     if (region == nullptr)
       return;
   } while ((op = region->getParentOp()));
@@ -539,6 +568,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite inserted op into region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -547,10 +583,19 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
 }
 
 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
+  // TODO: This notification should also be triggered when moving an op into
+  // this op.
   LLVM_DEBUG({
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error("greedy pattern rewrite modified op within region "
+                             "that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationModified(op);
   addToWorklist(op);
@@ -587,6 +632,12 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
         "scope region must not be erased during greedy pattern rewrite");
 #endif // NDEBUG
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite removed op from region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationRemoved(op);
 
@@ -736,16 +787,18 @@ LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
                                    GreedyRewriteConfig config, bool *changed) {
-  // The top-level operation must be known to be isolated from above to
-  // prevent performing canonicalizations on operations defined at or above
-  // the region containing 'op'.
-  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
-         "patterns can only be applied to operations IsolatedFromAbove");
-
   // Set scope if not specified.
   if (!config.scope)
     config.scope = &region;
 
+  // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+  // are out of scope are never added to the worklist and any out-of-scope IR
+  // modifications trigger an assertion when expensive checks are
+  // enabled (as long as the rewriter API is used correctly).
+  assert(
+      config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+      "greedy pattern rewrite scope must be IsolatedFromAbove");
+
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (failed(verify(config.scope->getParentOp())))
     llvm::report_fatal_error(
@@ -822,7 +875,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
   return success(worklist.empty());
 }
 
-/// Find the region that is the closest common ancestor of all given ops.
+/// Find the IsolatedFromAbove region that is the closest common ancestor of all
+/// given ops.
 ///
 /// Note: This function returns `nullptr` if there is a top-level op among the
 /// given list of ops.
@@ -832,6 +886,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   if (ops.size() == 1)
     return ops.front()->getParentRegion();
 
+  // Find the closest region that contains all ops.
   Region *region = ops.front()->getParentRegion();
   ops = ops.drop_front();
   int sz = ops.size();
@@ -848,6 +903,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
       break;
     region = region->getParentRegion();
   }
+
+  // Find the closest IsolatedFromAbove region.
+  while (region &&
+         !region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    region = region->getParentRegion();
+
   return region;
 }
 
@@ -868,8 +929,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
     // there is a top-level op among `ops`.
     config.scope = findCommonAncestor(ops);
   } else {
-    // If a scope was provided, make sure that all ops are in scope.
+    // If a scope was provided, make sure that it is IsolatedFromAbove and that
+    // all ops are in scope.
 #ifndef NDEBUG
+    // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+    // are out of scope are never added to the worklist and any out-of-scope IR
+    // modifications trigger an assertion when expensive checks are
+    // enabled (as long as the rewriter API is used correctly).
+    assert(
+        config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+        "greedy pattern rewrite scope must be IsolatedFromAbove");
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
       return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
     });

``````````

</details>


https://github.com/llvm/llvm-project/pull/76219

