[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for out-of-scope IR modifications (PR #76219)

Matthias Springer llvmlistbot at llvm.org
Fri Dec 22 00:51:23 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/76219

This commit adds an additional "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications.

`GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out of scope are not added to the worklist. The new expensive check triggers a fatal error if:
* An op is inserted into an out-of-scope region.
* An op is removed from an out-of-scope region.
* An op is modified in an out-of-scope region.
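The scope test shared by all three checks (`isOutOfScope` in the patch) can be modeled outside MLIR. What follows is a minimal sketch under stated assumptions: `Op`, `Region`, `attachRegion` and `insertInto` are hypothetical toy stand-ins, not MLIR classes, and only the parent links that the check reads are modeled. The procedure first walks outward from the scope to decide whether the op belongs to the same IR tree at all, and only then asks whether the op is nested inside the scope region. As a result, ops under an unlinked parent are never reported.

```cpp
#include <cassert>

// Hypothetical, heavily simplified stand-ins for mlir::Operation and
// mlir::Region (NOT the MLIR API): each op records the region it is nested
// in and each region records the op that owns it. A null parent means
// "top-level or unlinked".
struct Region;

struct Op {
  Region *parentRegion = nullptr;
};

struct Region {
  Op *parentOp = nullptr;

  // The region that contains this region's owner op, if any.
  Region *getParentRegion() const {
    return parentOp ? parentOp->parentRegion : nullptr;
  }

  // Models Region::findAncestorOpInRegion: returns the ancestor of `op`
  // (possibly `op` itself) that sits directly in this region, or nullptr.
  Op *findAncestorOpInRegion(Op *op) const {
    while (op) {
      if (op->parentRegion == this)
        return op;
      op = op->parentRegion ? op->parentRegion->parentOp : nullptr;
    }
    return nullptr;
  }
};

// Hypothetical helpers for building a toy IR tree.
void attachRegion(Op &owner, Region &region) { region.parentOp = &owner; }
void insertInto(Region &region, Op &op) {
  assert(!op.parentRegion && "op is already linked");
  op.parentRegion = &region;
}

// Same decision procedure as the isOutOfScope helper in the patch.
bool isOutOfScope(const Region *scope, Op *op) {
  if (!scope)
    return false; // No scope: nothing is out of scope.
  // Walk outward from the scope until reaching a region that contains `op`
  // or is owned by `op`. If there is none, `op` lives in a different (e.g.
  // unlinked) tree and is not reported.
  const Region *r = scope;
  while (r) {
    if (r->findAncestorOpInRegion(op) || r->parentOp == op)
      break;
    r = r->getParentRegion();
  }
  if (!r)
    return false;
  // Same tree: the op is out of scope unless it is nested in the scope.
  return !scope->findAncestorOpInRegion(op);
}
```

In this model, with a function body as the scope, ops in that body are in scope, while the owning function op, ops in a sibling function and the enclosing module are all reported. Widening the scope to the module body would accept edits inside every function, which is the trade-off behind the `TileAllocation` change.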

This change also tightens the greedy pattern rewriter entry points and ensures that the specified `scope` is a region of an `IsolatedFromAbove` op.
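The new default-scope rule for multi-op rewrites can be sketched in the same toy style. This is an illustrative model only: the `Op`/`Region` structs, `findCommonIsolatedAncestor` and `isValidScope` are hypothetical names, and the first step uses a plain `std::all_of` search rather than the bookkeeping in the patch. It first finds the closest region that contains every op, then widens it to the nearest region whose owner op is IsolatedFromAbove, returning `nullptr` when some op is top-level. One detail to check in the `findCommonAncestor` hunk: the existing `ops.size() == 1` fast path returns `ops.front()->getParentRegion()` before the new IsolatedFromAbove loop runs, whereas this sketch applies the walk in the single-op case too.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical toy model (NOT the MLIR API): ops know their parent region
// and whether they are IsolatedFromAbove; regions know their owner op.
struct Region;

struct Op {
  Region *parentRegion = nullptr;
  bool isolatedFromAbove = false;
};

struct Region {
  Op *parentOp = nullptr;

  Region *getParentRegion() const {
    return parentOp ? parentOp->parentRegion : nullptr;
  }

  // Ancestor of `op` (or `op` itself) directly nested in this region.
  Op *findAncestorOpInRegion(Op *op) const {
    while (op) {
      if (op->parentRegion == this)
        return op;
      op = op->parentRegion ? op->parentRegion->parentOp : nullptr;
    }
    return nullptr;
  }
};

// Models the precondition the entry points now assert on the scope.
bool isValidScope(const Region *scope) {
  return scope && scope->parentOp && scope->parentOp->isolatedFromAbove;
}

// Default scope for a multi-op rewrite: the closest IsolatedFromAbove region
// that contains every op, or nullptr if some op is top-level.
Region *findCommonIsolatedAncestor(const std::vector<Op *> &ops) {
  if (ops.empty())
    return nullptr;
  // Step 1: closest region that contains all ops.
  Region *region = ops.front()->parentRegion;
  while (region) {
    bool containsAll = std::all_of(ops.begin(), ops.end(), [&](Op *op) {
      return region->findAncestorOpInRegion(op) != nullptr;
    });
    if (containsAll)
      break;
    region = region->getParentRegion();
  }
  // Step 2: widen to the closest region owned by an IsolatedFromAbove op.
  while (region && !isValidScope(region))
    region = region->getParentRegion();
  return region;
}
```

For example, two ops inside a loop body nested in a function get the function body as their scope, not the (non-isolated) loop body.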

Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because the `func.func` op is not in scope. (We should not set the scope to be the enclosing module because the greedy pattern rewrite could then modify any function in the module.)

TODO: Should we allow this? Is there something special about functions? (They have no operands or results, so they are quite decoupled from other ops.)

From 67d95586a7f841159032f20c293830e00571b098 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 22 Dec 2023 17:45:36 +0900
Subject: [PATCH] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for
 out-of-scope IR modifications

This commit adds an additional "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications.

`GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out of scope are not added to the worklist. The new expensive check triggers a fatal error if:
* An op is inserted into an out-of-scope region.
* An op is removed from an out-of-scope region.
* An op is modified in an out-of-scope region.

This change also tightens the greedy pattern rewriter entry points and ensures that the specified `scope` is a region of an `IsolatedFromAbove` op.

Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because only function bodies are allowed to be modified. (TODO: Should we allow this? Is there something special about functions?)
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  2 +-
 .../Transforms/GreedyPatternRewriteDriver.h   | 18 +++-
 .../Utils/GreedyPatternRewriteDriver.cpp      | 85 +++++++++++++++++--
 3 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 4266ac5b0c8cf6..4d21bdf560852d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -90,7 +90,7 @@ def EnableArmStreaming
 }
 
 def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
+    : Pass<"allocate-arm-sme-tiles", "mlir::ModuleOp"> {
   let summary = "Allocate SME tiles";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index b93ffd96bee5fa..fe66ba81f10cc2 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -60,9 +60,21 @@ class GreedyRewriteConfig {
 
   static constexpr int64_t kNoLimit = -1;
 
-  /// Only ops within the scope are added to the worklist. If no scope is
-  /// specified, the closest enclosing region around the initial list of ops
-  /// is used as a scope.
+  /// Only ops within the scope are allowed to be modified and are added to the
+  /// worklist. If out-of-scope IR is modified, a fatal error is reported inside
+  /// the greedy pattern rewrite driver if expensive checks are enabled (as long
+  /// as rewrite patterns use the rewriter API correctly).
+  ///
+  /// The op that owns the scope region must be isolated from above. If no scope
+  /// is specified, it is set as follows:
+  /// * Single op greedy rewrite: a greedy rewrite is performed for every region
+  ///   of the op. See below. (The op must be isolated from above.)
+  /// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
+  ///   around the initial list of ops. If there is no such region, the scope
+  ///   is `nullptr`. This is because multi-op greedy rewrites are allowed to
+  ///   modify top-level ops. (They are not allowed to erase top-level ops.)
+  /// * Single region greedy rewrite: the specified region. (The op that owns
+  ///   the region must be isolated from above.)
   Region *scope = nullptr;
 
   /// Strict mode can restrict the ops that are added to the worklist during
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index eca13f52f53dc4..158c312dc60c82 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -327,6 +327,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
 private:
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// Return "true" if the given op is guaranteed to be out of scope.
+  bool isOutOfScope(Operation *op) const;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   /// Look over the provided operands for any defining operations that should
   /// be re-added to the worklist. This function should be called when an
   /// operation is modified or removed, as it may trigger further
@@ -378,6 +383,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
+  // No op is out of scope if no scope was set.
+  if (!config.scope)
+    return false;
+  // Check if the given op and the scope region are part of the same IR tree.
+  // The parent op into which the given op was inserted may be unlinked, in
+  // which case we do not consider the given op to be out of scope. (That parent
+  // op will likely be inserted later, together with all its nested ops.)
+  Region *r = config.scope;
+  while (r) {
+    if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
+      break;
+    r = r->getParentRegion();
+  }
+  if (!r)
+    return false;
+  // Op is out of scope if it is not within the scope region.
+  return !config.scope->findAncestorOpInRegion(*op);
+}
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
 bool GreedyPatternRewriteDriver::processWorklist() {
 #ifndef NDEBUG
   const char *logLineComment =
@@ -518,6 +545,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
         addSingleOpToWorklist(op);
       return;
     }
+    // TODO: Unlinked ops are currently not added to the worklist if a `scope`
+    // is specified.
     if (region == nullptr)
       return;
   } while ((op = region->getParentOp()));
@@ -539,6 +568,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite inserted op into region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -547,10 +583,19 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
 }
 
 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
+  // TODO: This notification should also be triggered when moving an op into
+  // this op.
   LLVM_DEBUG({
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error("greedy pattern rewrite modified op within region "
+                             "that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationModified(op);
   addToWorklist(op);
@@ -587,6 +632,12 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
         "scope region must not be erased during greedy pattern rewrite");
 #endif // NDEBUG
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (config.scope && isOutOfScope(op))
+    llvm::report_fatal_error(
+        "greedy pattern rewrite removed op from region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationRemoved(op);
 
@@ -736,16 +787,18 @@ LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
                                    GreedyRewriteConfig config, bool *changed) {
-  // The top-level operation must be known to be isolated from above to
-  // prevent performing canonicalizations on operations defined at or above
-  // the region containing 'op'.
-  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
-         "patterns can only be applied to operations IsolatedFromAbove");
-
   // Set scope if not specified.
   if (!config.scope)
     config.scope = &region;
 
+  // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+  // are out of scope are never added to the worklist and any out-of-scope IR
+  // modifications trigger a fatal error when expensive checks are
+  // enabled (as long as the rewriter API is used correctly).
+  assert(
+      config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+      "greedy pattern rewrite scope must be IsolatedFromAbove");
+
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (failed(verify(config.scope->getParentOp())))
     llvm::report_fatal_error(
@@ -822,7 +875,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
   return success(worklist.empty());
 }
 
-/// Find the region that is the closest common ancestor of all given ops.
+/// Find the IsolatedFromAbove region that is the closest common ancestor of all
+/// given ops.
 ///
 /// Note: This function returns `nullptr` if there is a top-level op among the
 /// given list of ops.
@@ -832,6 +886,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   if (ops.size() == 1)
     return ops.front()->getParentRegion();
 
+  // Find the closest region that contains all ops.
   Region *region = ops.front()->getParentRegion();
   ops = ops.drop_front();
   int sz = ops.size();
@@ -848,6 +903,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
       break;
     region = region->getParentRegion();
   }
+
+  // Find the closest IsolatedFromAbove region.
+  while (region &&
+         !region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    region = region->getParentRegion();
+
   return region;
 }
 
@@ -868,8 +929,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
     // there is a top-level op among `ops`.
     config.scope = findCommonAncestor(ops);
   } else {
-    // If a scope was provided, make sure that all ops are in scope.
+    // If a scope was provided, make sure that it is IsolatedFromAbove and that
+    // all ops are in scope.
 #ifndef NDEBUG
+    // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+    // are out of scope are never added to the worklist and any out-of-scope IR
+    // modifications trigger a fatal error when expensive checks are
+    // enabled (as long as the rewriter API is used correctly).
+    assert(
+        config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+        "greedy pattern rewrite scope must be IsolatedFromAbove");
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
       return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
     });


