[Mlir-commits] [mlir] 0e47355 - [mlir] Fix worklist bug in MultiOpPatternRewriteDriver
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 10 06:39:43 PST 2023
Author: Matthias Springer
Date: 2023-01-10T15:33:22+01:00
New Revision: 0e4735546e6bbcfd5d11d0a6b8b68cb9ccad9b41
URL: https://github.com/llvm/llvm-project/commit/0e4735546e6bbcfd5d11d0a6b8b68cb9ccad9b41
DIFF: https://github.com/llvm/llvm-project/commit/0e4735546e6bbcfd5d11d0a6b8b68cb9ccad9b41.diff
LOG: [mlir] Fix worklist bug in MultiOpPatternRewriteDriver
When `strict = true`, only the ops originally passed to the driver and ops newly created during rewriting are rewritten and/or folded. These ops are tracked in `strictModeFilteredOps`.
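Previously, newly-created ops were added to `strictModeFilteredOps` only after `addToWorklist` had already been called (via `GreedyPatternRewriteDriver::notifyOperationInserted`). In strict mode, `addToWorklist` ignores ops that are not in `strictModeFilteredOps`, so newly-created ops were never added to the worklist. This commit inserts the op into `strictModeFilteredOps` before forwarding to `GreedyPatternRewriteDriver::notifyOperationInserted`.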
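Also fix a test case that should have gone into an infinite loop but did not, due to this bug. `test.replace_with_same_op` was replaced with a new instance of itself, which should have caused the op to be rewritten over and over. The test now uses `test.replace_with_new_op`, which is replaced with a different op (`test.new_op`, or `test.erase_op` when the `create_erase_op` attribute is present).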
Differential Revision: https://reviews.llvm.org/D141141
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 5005a08bc29bb..cdb0b78c7a74e 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -558,9 +558,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
private:
void notifyOperationInserted(Operation *op) override {
- GreedyPatternRewriteDriver::notifyOperationInserted(op);
if (strictMode)
strictModeFilteredOps.insert(op);
+ GreedyPatternRewriteDriver::notifyOperationInserted(op);
}
void notifyOperationRemoved(Operation *op) override {
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 51d296935a97b..8c6eaf345d92d 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,6 +1,9 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
-// CHECK-LABEL: @test_erase
+// CHECK-LABEL: func @test_erase
+// CHECK: test.arg0
+// CHECK: test.arg1
+// CHECK-NOT: test.erase_op
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
@@ -8,16 +11,29 @@ func.func @test_erase() {
return
}
-// CHECK-LABEL: @test_insert_same_op
+// CHECK-LABEL: func @test_insert_same_op
+// CHECK: "test.insert_same_op"() {skip = true}
+// CHECK: "test.insert_same_op"() {skip = true}
func.func @test_insert_same_op() {
%0 = "test.insert_same_op"() : () -> (i32)
return
}
-// CHECK-LABEL: @test_replace_with_same_op
-func.func @test_replace_with_same_op() {
- %0 = "test.replace_with_same_op"() : () -> (i32)
+// CHECK-LABEL: func @test_replace_with_new_op
+// CHECK: %[[n:.*]] = "test.new_op"
+// CHECK: "test.dummy_user"(%[[n]])
+// CHECK: "test.dummy_user"(%[[n]])
+func.func @test_replace_with_new_op() {
+ %0 = "test.replace_with_new_op"() : () -> (i32)
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
%2 = "test.dummy_user"(%0) : (i32) -> (i32)
return
}
+
+// CHECK-LABEL: func @test_replace_with_erase_op
+// CHECK-NOT: test.replace_with_new_op
+// CHECK-NOT: test.erase_op
+func.func @test_replace_with_erase_op() {
+ "test.replace_with_new_op"() {create_erase_op} : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 9b74e808506f1..2573f76deb691 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -220,12 +220,12 @@ struct TestStrictPatternDriver
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
- patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
+ patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" ||
- opName == "test.replace_with_same_op" || opName == "test.erase_op") {
+ opName == "test.replace_with_new_op" || opName == "test.erase_op") {
ops.push_back(op);
}
});
@@ -260,16 +260,25 @@ struct TestStrictPatternDriver
};
// Replace an operation may introduce the re-visiting of its users.
- class ReplaceWithSameOp : public RewritePattern {
+ class ReplaceWithNewOp : public RewritePattern {
public:
- ReplaceWithSameOp(MLIRContext *context)
- : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
+ ReplaceWithNewOp(MLIRContext *context)
+ : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- Operation *newOp =
- rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- op->getOperands(), op->getResultTypes());
+ Operation *newOp;
+ if (op->hasAttr("create_erase_op")) {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.erase_op", op->getContext()).getIdentifier(),
+ ValueRange(), TypeRange());
+ } else {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ }
rewriter.replaceOp(op, newOp->getResults());
return success();
}
More information about the Mlir-commits
mailing list