[Mlir-commits] [mlir] 724a0e2 - [mlir] GreedyPatternRewriteDriver: Ignore scope when rewriting top-level ops
Matthias Springer
llvmlistbot at llvm.org
Fri Feb 3 00:57:05 PST 2023
Author: Matthias Springer
Date: 2023-02-03T09:56:55+01:00
New Revision: 724a0e2c2d7a5724dd81b00db470ba4bb8b616ca
URL: https://github.com/llvm/llvm-project/commit/724a0e2c2d7a5724dd81b00db470ba4bb8b616ca
DIFF: https://github.com/llvm/llvm-project/commit/724a0e2c2d7a5724dd81b00db470ba4bb8b616ca.diff
LOG: [mlir] GreedyPatternRewriteDriver: Ignore scope when rewriting top-level ops
Since D141945, top-level ModuleOps cannot be transformed with the GreedyPatternRewriteDriver because they have no enclosing region that could serve as a scope. Make the scope optional inside GreedyPatternRewriteDriver so that top-level ops can be processed when they are on the initial list of ops.
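Note: This does not allow users to bypass the scoping mechanism by setting `config.scope = nullptr`. A null scope is still replaced by the closest common ancestor region of the initial ops, and it remains `nullptr` only if one of those ops is a top-level op.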
Fixes #60462.
Differential Revision: https://reviews.llvm.org/D143151
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index d8c17c67357b0..423221dd80da0 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -61,7 +61,8 @@ class GreedyRewriteConfig {
static constexpr int64_t kNoLimit = -1;
/// Only ops within the scope are added to the worklist. If no scope is
- /// specified, the closest enclosing region is used as a scope.
+ /// specified, the closest enclosing region around the initial list of ops
+ /// is used as a scope.
Region *scope = nullptr;
/// Strict mode can restrict the ops that are added to the worklist during
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4c5868aead3f1..997bdc6a1c49f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -124,7 +124,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
- assert(config.scope && "scope is not specified");
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
@@ -266,19 +265,19 @@ bool GreedyPatternRewriteDriver::processWorklist() {
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
// Gather potential ancestors while looking for a "scope" parent region.
SmallVector<Operation *, 8> ancestors;
- ancestors.push_back(op);
- while (Region *region = op->getParentRegion()) {
- if (config.scope == region) {
- // All gathered ops are in fact ancestors.
- for (Operation *op : ancestors)
- addSingleOpToWorklist(op);
- break;
- }
- op = region->getParentOp();
- if (!op)
- break;
+ Region *region = nullptr;
+ do {
ancestors.push_back(op);
- }
+ region = op->getParentRegion();
+ if (config.scope == region) {
+ // Scope (can be `nullptr`) was reached. Stop traversal and enqueue ops.
+ for (Operation *op : ancestors)
+ addSingleOpToWorklist(op);
+ return;
+ }
+ if (region == nullptr)
+ return;
+ } while ((op = region->getParentOp()));
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
@@ -556,6 +555,9 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
}
/// Find the region that is the closest common ancestor of all given ops.
+///
+/// Note: This function returns `nullptr` if there is a top-level op among the
+/// given list of ops.
static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
assert(!ops.empty() && "expected at least one op");
// Fast path in case there is only one op.
@@ -566,7 +568,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
ops = ops.drop_front();
int sz = ops.size();
llvm::BitVector remainingOps(sz, true);
- do {
+ while (region) {
int pos = -1;
// Iterate over all remaining ops.
while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
@@ -576,8 +578,8 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
}
if (remainingOps.none())
break;
- } while ((region = region->getParentRegion()));
- assert(region && "could not find common parent region");
+ region = region->getParentRegion();
+ }
return region;
}
@@ -594,7 +596,8 @@ LogicalResult mlir::applyOpPatternsAndFold(
// Determine scope of rewrite.
if (!config.scope) {
- // Compute scope if none was provided.
+ // Compute scope if none was provided. The scope will remain `nullptr` if
+ // there is a top-level op among `ops`.
config.scope = findCommonAncestor(ops);
} else {
// If a scope was provided, make sure that all ops are in scope.
More information about the Mlir-commits
mailing list