[Mlir-commits] [mlir] 774416b - [mlir] GreedyPatternRewriteDriver: Keep track of surviving ops
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 26 00:22:00 PST 2023
Author: Matthias Springer
Date: 2023-01-26T09:21:51+01:00
New Revision: 774416bdb396e79e19d9c7ed49a49203c841c7c3
URL: https://github.com/llvm/llvm-project/commit/774416bdb396e79e19d9c7ed49a49203c841c7c3
DIFF: https://github.com/llvm/llvm-project/commit/774416bdb396e79e19d9c7ed49a49203c841c7c3.diff
LOG: [mlir] GreedyPatternRewriteDriver: Keep track of surviving ops
This change adds `allErased` to the `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)` overload. This overload now supports all functionality that is also supported by `applyOpPatternsAndFold(Operation *, ...)` and can be used as a replacement.
This change has no performance implications when `allErased = nullptr`.
The single-operation overload is removed in a subsequent NFC change.
Differential Revision: https://reviews.llvm.org/D141920
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 72b24754fac27..f72dbb7ff2986 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -118,10 +118,12 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
///
/// Returns success if the iterative process converged and no more patterns can
/// be matched. `changed` is set to true if the IR was modified at all.
+/// `allErased` is set to true if all ops in `ops` were erased.
LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteStrictness strictMode,
- bool *changed = nullptr);
+ bool *changed = nullptr,
+ bool *allErased = nullptr);
} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 56a94663e43e9..2b3a796dee93f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -17,6 +17,7 @@
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -584,8 +585,11 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
///
/// Note that ops in `ops` could be erased as a result of folding, becoming
/// dead, or via pattern rewrites. The return value indicates convergence.
- LogicalResult simplifyLocally(ArrayRef<Operation *> op,
- bool *changed = nullptr);
+ ///
+ /// All `ops` that survived the rewrite are stored in `surviving`.
+ LogicalResult
+ simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
+ llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
void addToWorklist(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::AnyOp ||
@@ -602,6 +606,8 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+ if (survivingOps)
+ survivingOps->erase(op);
if (strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
@@ -615,13 +621,25 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// depending on `strictMode`. This set is not maintained when `strictMode`
/// is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
+ /// An optional set of ops that survived the rewrite. This set is populated
+ /// at the beginning of `simplifyLocally` with the initially provided list
+ /// of ops.
+ llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr;
};
} // namespace
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
- bool *changed) {
+LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
+ ArrayRef<Operation *> ops, bool *changed,
+ llvm::SmallDenseSet<Operation *, 4> *surviving) {
+ auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
+ if (surviving) {
+ survivingOps = surviving;
+ survivingOps->clear();
+ survivingOps->insert(ops.begin(), ops.end());
+ }
+
if (strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -726,15 +744,22 @@ LogicalResult mlir::applyOpPatternsAndFold(
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed) {
+ GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
+ if (allErased)
+ *allErased = true;
return success();
}
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strictMode);
- return driver.simplifyLocally(ops, changed);
+ llvm::SmallDenseSet<Operation *, 4> surviving;
+ LogicalResult converged =
+ driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+ if (allErased)
+ *allErased = surviving.empty();
+ return converged;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index ad6f6a5f01e20..9dbaea18967f8 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -7,6 +7,7 @@
// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
// CHECK-EN-LABEL: func @test_erase
+// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: test.arg0
// CHECK-EN: test.arg1
// CHECK-EN-NOT: test.erase_op
@@ -20,6 +21,7 @@ func.func @test_erase() {
// -----
// CHECK-EN-LABEL: func @test_insert_same_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
func.func @test_insert_same_op() {
@@ -30,6 +32,7 @@ func.func @test_insert_same_op() {
// -----
// CHECK-EN-LABEL: func @test_replace_with_new_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
// CHECK-EN: "test.dummy_user"(%[[n]])
// CHECK-EN: "test.dummy_user"(%[[n]])
@@ -43,10 +46,12 @@ func.func @test_replace_with_new_op() {
// -----
// CHECK-EN-LABEL: func @test_replace_with_erase_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: test.replace_with_new_op
// CHECK-EN-NOT: test.erase_op
// CHECK-EX-LABEL: func @test_replace_with_erase_op
+// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EX-NOT: test.replace_with_new_op
// CHECK-EX: test.erase_op
func.func @test_replace_with_erase_op() {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 726d8100c4c65..98896c736a3cb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -254,8 +254,9 @@ struct TestStrictPatternDriver
}
void runOnOperation() override {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
+ MLIRContext *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
@@ -279,7 +280,14 @@ struct TestStrictPatternDriver
// Check if these transformations introduce visiting of operations that
// are not in the `ops` set (The new created ops are valid). An invalid
// operation will trigger the assertion while processing.
- (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
+ bool changed = false;
+ bool allErased = false;
+ (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
+ &changed, &allErased);
+ Builder b(ctx);
+ getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
+ getOperation()->setAttr("pattern_driver_all_erased",
+ b.getBoolAttr(allErased));
}
Option<std::string> strictMode{
More information about the Mlir-commits
mailing list