[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Check for out-of-scope IR modifications (PR #76219)
llvmlistbot at llvm.org
Fri Dec 22 00:51:47 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds an additional "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications.
`GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out-of-scope are not added to the worklist. The new expensive check triggers a fatal error if:
* Op is inserted into out-of-scope region.
* Op is removed from out-of-scope region.
* Op is modified in out-of-scope region.
This change also tightens the greedy pattern rewriter entry points and makes sure that the specified `scope` is an `IsolatedFromAbove` region.
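The scope test behind these three checks can be modeled outside of MLIR. The sketch below is a simplified stand-in, not the driver's code: `ToyOp`, `ToyRegion`, `ToyIR`, `enclosingRegion`, `isNestedIn`, and `checkToyScopes` are hypothetical names, and only the control flow of `isOutOfScope` follows the patch. It illustrates why a function op is out of scope when the scope is that function's own body, and why ops in an unlinked IR tree are not reported.

```cpp
#include <cassert>

// Hypothetical toy stand-ins for mlir::Operation and mlir::Region.
struct ToyOp;
struct ToyRegion {
  ToyOp *parentOp = nullptr; // op that owns this region
};
struct ToyOp {
  ToyRegion *parentRegion = nullptr; // nullptr: top-level or unlinked op
};

// Region that encloses the owner of `region`, or nullptr at the top of a tree.
inline const ToyRegion *enclosingRegion(const ToyRegion *region) {
  return region->parentOp ? region->parentOp->parentRegion : nullptr;
}

// True if `op` is nested at any depth inside `region`.
inline bool isNestedIn(const ToyRegion *region, const ToyOp *op) {
  for (const ToyRegion *r = op->parentRegion; r; r = enclosingRegion(r))
    if (r == region)
      return true;
  return false;
}

// Mirrors the control flow of the patch's isOutOfScope on the toy types.
inline bool isOutOfScope(const ToyRegion *scope, const ToyOp *op) {
  if (!scope)
    return false; // no scope set: nothing is out of scope
  // Walk outwards from the scope until some enclosing region contains `op`
  // or is owned by `op`. If none does, `op` lives in a different IR tree.
  const ToyRegion *r = scope;
  while (r) {
    if (isNestedIn(r, op) || r->parentOp == op)
      break;
    r = enclosingRegion(r);
  }
  if (!r)
    return false; // unlinked IR is not reported
  return !isNestedIn(scope, op);
}

// Toy IR: module { funcA { opInA }, funcB { opInB } } plus one unlinked op.
struct ToyIR {
  ToyOp module, funcA, funcB, opInA, opInB, unlinkedOp;
  ToyRegion moduleBody, bodyA, bodyB;
  ToyIR() {
    moduleBody.parentOp = &module;
    funcA.parentRegion = &moduleBody;
    funcB.parentRegion = &moduleBody;
    bodyA.parentOp = &funcA;
    bodyB.parentOp = &funcB;
    opInA.parentRegion = &bodyA;
    opInB.parentRegion = &bodyB;
  }
  ToyIR(const ToyIR &) = delete; // members point at each other
  ToyIR &operator=(const ToyIR &) = delete;
};

// Usage example: the scope is the body of funcA.
inline void checkToyScopes() {
  ToyIR ir;
  assert(!isOutOfScope(&ir.bodyA, &ir.opInA));      // inside the scope
  assert(isOutOfScope(&ir.bodyA, &ir.opInB));       // sibling function
  assert(isOutOfScope(&ir.bodyA, &ir.funcA));       // owner of the scope
  assert(!isOutOfScope(&ir.bodyA, &ir.unlinkedOp)); // different IR tree
  assert(!isOutOfScope(nullptr, &ir.opInB));        // no scope set
}
```

In this model the op that owns the scope region is itself out of scope, which is the situation described for `TileAllocation` in the note that follows.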
Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because the `func.func` op is not in scope. (We should not set the scope to be the enclosing module because the greedy pattern rewrite could then modify any function in the module.)
TODO: Should we allow this? Is there something special about functions? (They have no operands or results, so they are quite decoupled from other ops.)
---
Full diff: https://github.com/llvm/llvm-project/pull/76219.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+1-1)
- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+15-3)
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+77-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 4266ac5b0c8cf6..4d21bdf560852d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -90,7 +90,7 @@ def EnableArmStreaming
}
def TileAllocation
- : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
+ : Pass<"allocate-arm-sme-tiles", "mlir::ModuleOp"> {
let summary = "Allocate SME tiles";
let description = [{
This pass does tile allocation for SME "virtual tiles". It is run at the
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index b93ffd96bee5fa..fe66ba81f10cc2 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -60,9 +60,21 @@ class GreedyRewriteConfig {
static constexpr int64_t kNoLimit = -1;
- /// Only ops within the scope are added to the worklist. If no scope is
- /// specified, the closest enclosing region around the initial list of ops
- /// is used as a scope.
+ /// Only ops within the scope are allowed to be modified and are added to the
+ /// worklist. If out-of-scope IR is modified, an assertion will fail inside
+ /// the greedy pattern rewrite driver if expensive checks are enabled (as long
+ /// as rewrite patterns use the rewriter API correctly).
+ ///
+ /// The op that owns the scope region must be isolated from above. If no scope
+ /// is specified, it is set as follows:
+ /// * Single op greedy rewrite: a greedy rewrite is performed for every region
+ /// of the op. See below. (The op must be isolated from above.)
+ /// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
+ /// around the initial list of ops. If there is no such region, the scope
+ /// is `nullptr`. This is because multi-op greedy rewrites are allowed to
+ /// modify top-level ops. (They are not allowed to erase top-level ops.)
+ /// * Single region greedy rewrite: the specified region. (The op that owns
+ /// the region must be isolated from above.)
Region *scope = nullptr;
/// Strict mode can restrict the ops that are added to the worklist during
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index eca13f52f53dc4..158c312dc60c82 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -327,6 +327,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
private:
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ /// Return "true" if the given op is guaranteed to be out of scope.
+ bool isOutOfScope(Operation *op) const;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
/// Look over the provided operands for any defining operations that should
/// be re-added to the worklist. This function should be called when an
/// operation is modified or removed, as it may trigger further
@@ -378,6 +383,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
+ // No op is out of scope if no scope was set.
+ if (!config.scope)
+ return false;
+ // Check if the given op and the scope region are part of the same IR tree.
+ // The parent op into which the given op was inserted may be unlinked, in
+ // which case we do not consider the given op to be out of scope. (That parent
+ // op will likely be inserted later, together with all its nested ops.)
+ Region *r = config.scope;
+ while (r) {
+ if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
+ break;
+ r = r->getParentRegion();
+ }
+ if (!r)
+ return false;
+ // Op is out of scope if it is not within the scope region.
+ return !config.scope->findAncestorOpInRegion(*op);
+}
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
bool GreedyPatternRewriteDriver::processWorklist() {
#ifndef NDEBUG
const char *logLineComment =
@@ -518,6 +545,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
addSingleOpToWorklist(op);
return;
}
+ // TODO: Unlinked ops are currently not added to the worklist if a `scope`
+ // is specified.
if (region == nullptr)
return;
} while ((op = region->getParentOp()));
@@ -539,6 +568,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (config.scope && isOutOfScope(op))
+ llvm::report_fatal_error(
+ "greedy pattern rewrite inserted op into region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
if (config.listener)
config.listener->notifyOperationInserted(op);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -547,10 +583,19 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
}
void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
+ // TODO: This notification should also be triggered when moving an op into
+ // this op.
LLVM_DEBUG({
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
<< ")\n";
});
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (config.scope && isOutOfScope(op))
+ llvm::report_fatal_error("greedy pattern rewrite modified op within region "
+ "that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
if (config.listener)
config.listener->notifyOperationModified(op);
addToWorklist(op);
@@ -587,6 +632,12 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
"scope region must not be erased during greedy pattern rewrite");
#endif // NDEBUG
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (config.scope && isOutOfScope(op))
+ llvm::report_fatal_error(
+ "greedy pattern rewrite removed op from region that is out of scope");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
if (config.listener)
config.listener->notifyOperationRemoved(op);
@@ -736,16 +787,18 @@ LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed) {
- // The top-level operation must be known to be isolated from above to
- // prevent performing canonicalizations on operations defined at or above
- // the region containing 'op'.
- assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
- "patterns can only be applied to operations IsolatedFromAbove");
-
// Set scope if not specified.
if (!config.scope)
config.scope = &region;
+ // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+ // are out of scope are never added to the worklist and any out-of-scope IR
+ // modifications trigger an assertion when expensive checks are
+ // enabled (as long as the rewriter API is used correctly).
+ assert(
+ config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "greedy pattern rewrite scope must be IsolatedFromAbove");
+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error(
@@ -822,7 +875,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
return success(worklist.empty());
}
-/// Find the region that is the closest common ancestor of all given ops.
+/// Find the IsolatedFromAbove region that is the closest common ancestor of all
+/// given ops.
///
/// Note: This function returns `nullptr` if there is a top-level op among the
/// given list of ops.
@@ -832,6 +886,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
if (ops.size() == 1)
return ops.front()->getParentRegion();
+ // Find the closest region that contains all ops.
Region *region = ops.front()->getParentRegion();
ops = ops.drop_front();
int sz = ops.size();
@@ -848,6 +903,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
break;
region = region->getParentRegion();
}
+
+ // Find the closest IsolatedFromAbove region.
+ while (region &&
+ !region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ region = region->getParentRegion();
+
return region;
}
@@ -868,8 +929,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
// there is a top-level op among `ops`.
config.scope = findCommonAncestor(ops);
} else {
- // If a scope was provided, make sure that all ops are in scope.
+ // If a scope was provided, make sure that it is IsolatedFromAbove and that
+ // all ops are in scope.
#ifndef NDEBUG
+ // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
+ // are out of scope are never added to the worklist and any out-of-scope IR
+ // modifications trigger an assertion when expensive checks are
+ // enabled (as long as the rewriter API is used correctly).
+ assert(
+ config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "greedy pattern rewrite scope must be IsolatedFromAbove");
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
});
``````````
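The loop added at the end of `findCommonAncestor` in the diff above widens the common ancestor region until its owner is isolated from above. A minimal standalone sketch of that walk follows. It assumes toy types rather than the MLIR API: `ToyOp`, `ToyRegion`, `ToyNest`, `closestIsolatedRegion`, and `checkClosestIsolatedRegion` are hypothetical names, and the `isolatedFromAbove` flag stands in for `hasTrait<OpTrait::IsIsolatedFromAbove>()`.

```cpp
#include <cassert>

// Hypothetical toy stand-ins for mlir::Operation and mlir::Region.
struct ToyOp;
struct ToyRegion {
  ToyOp *parentOp = nullptr; // assumed non-null for every region used here
};
struct ToyOp {
  ToyRegion *parentRegion = nullptr; // nullptr: top-level op
  bool isolatedFromAbove = false;    // stand-in for the op trait
};

// Walk outwards until the region's owner is isolated from above. Returns
// nullptr if no enclosing region has an isolated owner.
inline ToyRegion *closestIsolatedRegion(ToyRegion *region) {
  while (region && !region->parentOp->isolatedFromAbove)
    region = region->parentOp->parentRegion;
  return region;
}

// Toy IR: module { func { loop { ... } } } plus a top-level op `wrapper`
// that is not isolated from above.
struct ToyNest {
  ToyOp module, func, loop, wrapper;
  ToyRegion moduleBody, funcBody, loopBody, wrapperBody;
  ToyNest() {
    module.isolatedFromAbove = true;
    func.isolatedFromAbove = true;
    moduleBody.parentOp = &module;
    func.parentRegion = &moduleBody;
    funcBody.parentOp = &func;
    loop.parentRegion = &funcBody;
    loopBody.parentOp = &loop;
    wrapperBody.parentOp = &wrapper;
  }
  ToyNest(const ToyNest &) = delete; // members point at each other
  ToyNest &operator=(const ToyNest &) = delete;
};

// Usage example.
inline void checkClosestIsolatedRegion() {
  ToyNest n;
  // A loop body is widened to the enclosing function body.
  assert(closestIsolatedRegion(&n.loopBody) == &n.funcBody);
  // A function body already has an isolated owner.
  assert(closestIsolatedRegion(&n.funcBody) == &n.funcBody);
  // No isolated owner anywhere above: the scope ends up as nullptr.
  assert(closestIsolatedRegion(&n.wrapperBody) == nullptr);
}
```

The `nullptr` result corresponds to the multi-op case in the header comment, where the scope is left unset.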
</details>
https://github.com/llvm/llvm-project/pull/76219