[Mlir-commits] [mlir] 87e345b - [mlir] GreedyPatternRewriteDriver: Add new strict mode option
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 20 01:08:22 PST 2023
Author: Matthias Springer
Date: 2023-01-20T10:08:11+01:00
New Revision: 87e345b1bdb76867cc6e9ae59b6dd2633a480d38
URL: https://github.com/llvm/llvm-project/commit/87e345b1bdb76867cc6e9ae59b6dd2633a480d38
DIFF: https://github.com/llvm/llvm-project/commit/87e345b1bdb76867cc6e9ae59b6dd2633a480d38.diff
LOG: [mlir] GreedyPatternRewriteDriver: Add new strict mode option
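The `bool strict` parameter of `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)` is replaced by the `GreedyRewriteStrictness` enum, which has three options:
* `AnyOp` (previously `strict = false`): no restrictions on which ops are processed.
* `ExistingAndNewOps` (previously `strict = true`): only pre-existing and newly created ops are processed.
* `ExistingOps` (new): only pre-existing ops are processed.
The last option corresponds to the behavior of the `applyOpPatternsAndFold(Operation *, ...)` overload; it is now also supported by the `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)` overload.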
Differential Revision: https://reviews.llvm.org/D141904
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index aaafebf252fab..72b24754fac27 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,17 @@
namespace mlir {
+/// This enum controls which ops are put on the worklist during a greedy
+/// pattern rewrite.
+enum class GreedyRewriteStrictness {
+ /// No restrictions on which ops are processed.
+ AnyOp,
+ /// Only pre-existing and newly created ops are processed.
+ ExistingAndNewOps,
+ /// Only pre-existing ops are processed.
+ ExistingOps
+};
+
/// This class allows control over how the GreedyPatternRewriteDriver works.
class GreedyRewriteConfig {
public:
@@ -88,21 +99,29 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
bool *erased = nullptr);
/// Applies the specified rewrite patterns on `ops` while also trying to fold
-/// these ops as well as any other ops that were in turn created due to such
-/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops`
-/// remain completely unmodified if `strict` is set to true. If `strict` is
-/// false, other operations that use results of rewritten ops or supply operands
-/// to such ops are in turn simplified; any other ops still remain unmodified
-/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a
-/// result of folding, becoming dead, or via pattern rewrites. If more far
-/// reaching simplification is desired, applyPatternsAndFoldGreedily should be
-/// used.
+/// these ops.
+///
+/// Newly created ops and other pre-existing ops that use results of rewritten
+/// ops or supply operands to such ops are simplified, unless such ops are
+/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
+/// of `strictMode`).
+///
+/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
+/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly
+/// created ops are simplified. All other ops are excluded.
+/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
+/// simplified. All other ops are excluded.
+///
+/// Note that ops in `ops` could be erased as a result of folding, becoming
+/// dead, or via pattern rewrites. If more far-reaching simplification is
+/// desired, applyPatternsAndFoldGreedily should be used.
///
/// Returns success if the iterative process converged and no more patterns can
/// be matched. `changed` is set to true if the IR was modified at all.
LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
- bool strict, bool *changed = nullptr);
+ GreedyRewriteStrictness strictMode,
+ bool *changed = nullptr);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 282a35b45c537..d516de8ee6d26 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -132,7 +132,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
// Apply the simplification pattern to a fixpoint.
if (failed(
- applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) {
+ applyOpPatternsAndFold(targets, frozenPatterns,
+ GreedyRewriteStrictness::ExistingAndNewOps))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 0d84d384012c2..a9d6f940200b0 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -239,5 +239,6 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true);
+ (void)applyOpPatternsAndFold(copyOps, frozenPatterns,
+ GreedyRewriteStrictness::ExistingAndNewOps);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index bb5b390cb2ea4..6cb0a30dce39f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -105,5 +105,6 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
- (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true);
+ (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
+ GreedyRewriteStrictness::ExistingAndNewOps);
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index b7ea592bfcc7d..56a94663e43e9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -575,66 +575,54 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
- bool strict)
+ GreedyRewriteStrictness strictMode)
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
- strictMode(strict) {}
+ strictMode(strictMode) {}
+ /// Performs the specified rewrites on `ops` while also trying to fold these
+ /// ops. `strictMode` controls which other ops are simplified.
+ ///
+ /// Note that ops in `ops` could be erased as a result of folding, becoming
+ /// dead, or via pattern rewrites. The return value indicates convergence.
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
bool *changed = nullptr);
void addToWorklist(Operation *op) override {
- if (!strictMode || strictModeFilteredOps.contains(op))
+ if (strictMode == GreedyRewriteStrictness::AnyOp ||
+ strictModeFilteredOps.contains(op))
GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
}
private:
void notifyOperationInserted(Operation *op) override {
- if (strictMode)
+ if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
GreedyPatternRewriteDriver::notifyOperationInserted(op);
}
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
- if (strictMode)
+ if (strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
- /// If `strictMode` is true, any pre-existing ops outside of
- /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
- /// If `strictMode` is false, operations that use results of (or supply
- /// operands to) any rewritten ops stemming from the simplification of the
- /// provided ops are in turn simplified; any other ops still remain untouched
- /// (i.e., regardless of `strictMode`).
- bool strictMode = false;
-
- /// The list of ops we are restricting our rewrites to if `strictMode` is on.
- /// These include the supplied set of ops as well as new ops created while
- /// rewriting those ops. This set is not maintained when strictMode is off.
+ /// `strictMode` controls which ops are added to the worklist during
+ /// simplification.
+ GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+
+ /// The set of ops we are restricting our rewrites to: the supplied ops and,
+ /// if `strictMode` is ExistingAndNewOps, new ops created while rewriting
+ /// them. This set is not maintained when `strictMode` is
+ /// GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
};
} // namespace
-/// Performs the specified rewrites on `ops` while also trying to fold these ops
-/// as well as any other ops that were in turn created due to these rewrite
-/// patterns. Any pre-existing ops outside of `ops` remain completely
-/// unmodified if `strictMode` is true. If `strictMode` is false, other
-/// operations that use results of rewritten ops or supply operands to such ops
-/// are in turn simplified; any other ops still remain unmodified (i.e.,
-/// regardless of `strictMode`). Note that ops in `ops` could be erased as a
-/// result of folding, becoming dead, or via pattern rewrites. Returns true if
-/// at all any changes happened.
-// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
-// or GreedyPatternRewriteDriver::simplify, this method just iterates until
-// the worklist is empty. As our objective is to keep simplification "local",
-// there is no strong rationale to re-add all operations into the worklist and
-// rerun until an iteration changes nothing. If more widereaching simplification
-// is desired, GreedyPatternRewriteDriver should be used.
LogicalResult
MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
bool *changed) {
- if (strictMode) {
+ if (strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
@@ -659,7 +647,8 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
if (op == nullptr)
continue;
- assert((!strictMode || strictModeFilteredOps.contains(op)) &&
+ assert((strictMode == GreedyRewriteStrictness::AnyOp ||
+ strictModeFilteredOps.contains(op)) &&
"unexpected op was inserted under strict mode");
// If the operation is trivially dead - remove it.
@@ -718,9 +707,6 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
return success(worklist.empty());
}
-/// Rewrites only `op` using the supplied canonicalization patterns and
-/// folding. `erased` is set to true if the op is erased as a result of being
-/// folded, replaced, or dead.
LogicalResult mlir::applyOpPatternsAndFold(
Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
// Start the pattern driver.
@@ -738,10 +724,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
return converged;
}
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- bool strict, bool *changed) {
+LogicalResult mlir::applyOpPatternsAndFold(
+ ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode, bool *changed) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -750,6 +735,6 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- strict);
+ strictMode);
return driver.simplifyLocally(ops, changed);
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 8c6eaf345d92d..ad6f6a5f01e20 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,9 +1,15 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN: -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
+// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
-// CHECK-LABEL: func @test_erase
-// CHECK: test.arg0
-// CHECK: test.arg1
-// CHECK-NOT: test.erase_op
+// RUN: mlir-opt \
+// RUN: -test-strict-pattern-driver="strictness=ExistingOps" \
+// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
+
+// CHECK-EN-LABEL: func @test_erase
+// CHECK-EN: test.arg0
+// CHECK-EN: test.arg1
+// CHECK-EN-NOT: test.erase_op
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
@@ -11,18 +17,22 @@ func.func @test_erase() {
return
}
-// CHECK-LABEL: func @test_insert_same_op
-// CHECK: "test.insert_same_op"() {skip = true}
-// CHECK: "test.insert_same_op"() {skip = true}
+// -----
+
+// CHECK-EN-LABEL: func @test_insert_same_op
+// CHECK-EN: "test.insert_same_op"() {skip = true}
+// CHECK-EN: "test.insert_same_op"() {skip = true}
func.func @test_insert_same_op() {
%0 = "test.insert_same_op"() : () -> (i32)
return
}
-// CHECK-LABEL: func @test_replace_with_new_op
-// CHECK: %[[n:.*]] = "test.new_op"
-// CHECK: "test.dummy_user"(%[[n]])
-// CHECK: "test.dummy_user"(%[[n]])
+// -----
+
+// CHECK-EN-LABEL: func @test_replace_with_new_op
+// CHECK-EN: %[[n:.*]] = "test.new_op"
+// CHECK-EN: "test.dummy_user"(%[[n]])
+// CHECK-EN: "test.dummy_user"(%[[n]])
func.func @test_replace_with_new_op() {
%0 = "test.replace_with_new_op"() : () -> (i32)
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
@@ -30,9 +40,15 @@ func.func @test_replace_with_new_op() {
return
}
-// CHECK-LABEL: func @test_replace_with_erase_op
-// CHECK-NOT: test.replace_with_new_op
-// CHECK-NOT: test.erase_op
+// -----
+
+// CHECK-EN-LABEL: func @test_replace_with_erase_op
+// CHECK-EN-NOT: test.replace_with_new_op
+// CHECK-EN-NOT: test.erase_op
+
+// CHECK-EX-LABEL: func @test_replace_with_erase_op
+// CHECK-EX-NOT: test.replace_with_new_op
+// CHECK-EX: test.erase_op
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 117f83e01f9ca..7dc478c8b9cf1 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -132,7 +132,8 @@ void TestAffineDataCopy::runOnOperation() {
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
}
}
- (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true);
+ (void)applyOpPatternsAndFold(copyOps, std::move(patterns),
+ GreedyRewriteStrictness::ExistingAndNewOps);
}
namespace mlir {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d3ef160176cc3..286d0de72cc9a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -244,11 +244,13 @@ struct TestStrictPatternDriver
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
TestStrictPatternDriver() = default;
- TestStrictPatternDriver(const TestStrictPatternDriver &other) = default;
+ TestStrictPatternDriver(const TestStrictPatternDriver &other) {
+ strictMode = other.strictMode;
+ }
StringRef getArgument() const final { return "test-strict-pattern-driver"; }
StringRef getDescription() const final {
- return "Run strict mode of pattern driver";
+ return "Test strict mode of pattern driver";
}
void runOnOperation() override {
@@ -263,13 +265,28 @@ struct TestStrictPatternDriver
}
});
+ GreedyRewriteStrictness mode;
+ if (strictMode == "AnyOp") {
+ mode = GreedyRewriteStrictness::AnyOp;
+ } else if (strictMode == "ExistingAndNewOps") {
+ mode = GreedyRewriteStrictness::ExistingAndNewOps;
+ } else if (strictMode == "ExistingOps") {
+ mode = GreedyRewriteStrictness::ExistingOps;
+ } else {
+ llvm_unreachable("invalid strictness option");
+ }
+
// Check if these transformations introduce visiting of operations that
// are not in the `ops` set (The new created ops are valid). An invalid
// operation will trigger the assertion while processing.
- (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns),
- /*strict=*/true);
+ (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
}
+ Option<std::string> strictMode{
+ *this, "strictness",
+ llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
+ llvm::cl::init("AnyOp")};
+
private:
// New inserted operation is valid for further transformation.
class InsertSameOp : public RewritePattern {
More information about the Mlir-commits
mailing list