[Mlir-commits] [mlir] ba3a9f5 - [mlir:MultiOpDriver] Add operands to worklist should be checked
Chia-hung Duan
llvmlistbot at llvm.org
Sat Jun 11 09:23:28 PDT 2022
Author: Chia-hung Duan
Date: 2022-06-11T15:56:23Z
New Revision: ba3a9f51ffd903edf97b0cb7d97c073d907fee30
URL: https://github.com/llvm/llvm-project/commit/ba3a9f51ffd903edf97b0cb7d97c073d907fee30
DIFF: https://github.com/llvm/llvm-project/commit/ba3a9f51ffd903edf97b0cb7d97c073d907fee30.diff
LOG: [mlir:MultiOpDriver] Add operands to worklist should be checked
An operand's defining op may not be valid for adding to the worklist under
strict mode.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D127180
Added:
mlir/test/Transforms/test-strict-pattern-driver.mlir
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7697e603507a4..7305a376449d4 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -43,7 +43,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
bool simplify(MutableArrayRef<Region> regions);
/// Add the given operation to the worklist.
- void addToWorklist(Operation *op);
+ virtual void addToWorklist(Operation *op);
/// Pop the next operation from the worklist.
Operation *popFromWorklist();
@@ -60,8 +60,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
- template <typename Operands>
- void addToWorklist(Operands &&operands);
+ void addOperandsToWorklist(ValueRange operands);
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
@@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
originalOperands.assign(op->operand_begin(), op->operand_end());
auto preReplaceAction = [&](Operation *op) {
// Add the operands to the worklist for visitation.
- addToWorklist(originalOperands);
+ addOperandsToWorklist(originalOperands);
// Add all the users of the result to the worklist so we make sure
// to revisit them.
@@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
addToWorklist(op);
}
-template <typename Operands>
-void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
+void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
@@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
}
void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
- addToWorklist(op->getOperands());
+ addOperandsToWorklist(op->getOperands());
op->walk([this](Operation *operation) {
removeFromWorklist(operation);
folder.notifyRemoval(operation);
@@ -523,22 +521,12 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
bool simplifyLocally(ArrayRef<Operation *> op);
-private:
- // Look over the provided operands for any defining operations that should
- // be re-added to the worklist. This function should be called when an
- // operation is modified or removed, as it may trigger further
- // simplifications. If `strict` is set to true, only ops in
- // `strictModeFilteredOps` are considered.
- template <typename Operands>
- void addOperandsToWorklist(Operands &&operands) {
- for (Value operand : operands) {
- if (auto *defOp = operand.getDefiningOp()) {
- if (!strictMode || strictModeFilteredOps.contains(defOp))
- addToWorklist(defOp);
- }
- }
+ void addToWorklist(Operation *op) override {
+ if (!strictMode || strictModeFilteredOps.contains(op))
+ GreedyPatternRewriteDriver::addToWorklist(op);
}
+private:
void notifyOperationInserted(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationInserted(op);
if (strictMode)
@@ -551,15 +539,6 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
strictModeFilteredOps.erase(op);
}
- void notifyRootReplaced(Operation *op) override {
- for (auto result : op->getResults()) {
- for (auto *user : result.getUsers()) {
- if (!strictMode || strictModeFilteredOps.contains(user))
- addToWorklist(user);
- }
- }
- }
-
/// If `strictMode` is true, any pre-existing ops outside of
/// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
/// If `strictMode` is false, operations that use results of (or supply
@@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// Add all the users of the result to the worklist so we make sure
// to revisit them.
- for (Value result : op->getResults())
- for (Operation *userOp : result.getUsers()) {
- if (!strictMode || strictModeFilteredOps.contains(userOp))
- addToWorklist(userOp);
- }
+ for (Value result : op->getResults()) {
+ for (Operation *userOp : result.getUsers())
+ addToWorklist(userOp);
+ }
+
notifyOperationRemoved(op);
};
// Add the given operation generated by the folder to the worklist.
auto processGeneratedConstants = [this](Operation *op) {
- // Newly created ops are also simplified -- these are also "local".
- addToWorklist(op);
- // When strict mode is off, we don't need to maintain
- // strictModeFilteredOps.
- if (strictMode)
- strictModeFilteredOps.insert(op);
+ notifyOperationInserted(op);
};
// Try to fold this op.
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
new file mode 100644
index 0000000000000..51d296935a97b
--- /dev/null
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
+
+// CHECK-LABEL: @test_erase
+func.func @test_erase() { // EraseOp erases %erase; test.arg0/test.arg1 are outside the strict op set.
+ %0 = "test.arg0"() : () -> (i32)
+ %1 = "test.arg1"() : () -> (i32)
+ %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
+ return
+}
+
+// CHECK-LABEL: @test_insert_same_op
+func.func @test_insert_same_op() { // InsertSameOp creates a new op, which strict mode is allowed to visit.
+ %0 = "test.insert_same_op"() : () -> (i32)
+ return
+}
+
+// CHECK-LABEL: @test_replace_with_same_op
+func.func @test_replace_with_same_op() { // The test.dummy_user ops are outside the strict op set.
+ %0 = "test.replace_with_same_op"() : () -> (i32)
+ %1 = "test.dummy_user"(%0) : (i32) -> (i32)
+ %2 = "test.dummy_user"(%0) : (i32) -> (i32)
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 264b118c8956d..d23f69ded8f99 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -176,6 +176,91 @@ struct TestPatternDriver
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
};
+
+struct TestStrictPatternDriver
+ : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
+
+ TestStrictPatternDriver() = default;
+ TestStrictPatternDriver(const TestStrictPatternDriver &other)
+ : PassWrapper(other) {}
+
+ StringRef getArgument() const final { return "test-strict-pattern-driver"; }
+ StringRef getDescription() const final {
+ return "Run strict mode of pattern driver";
+ }
+
+ void runOnOperation() override {
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
+ SmallVector<Operation *> ops; // The op set handed to the strict-mode driver.
+ getOperation()->walk([&](Operation *op) {
+ StringRef opName = op->getName().getStringRef();
+ if (opName == "test.insert_same_op" ||
+ opName == "test.replace_with_same_op" || opName == "test.erase_op") {
+ ops.push_back(op);
+ }
+ });
+
+ // Check whether these transformations cause visiting of operations that
+ // are not in the `ops` set (newly created ops are allowed). Visiting an
+ // invalid operation will trigger an assertion while processing.
+ (void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
+ /*strict=*/true);
+ }
+
+private:
+ // A newly inserted operation is valid for further transformation.
+ class InsertSameOp : public RewritePattern {
+ public:
+ InsertSameOp(MLIRContext *context)
+ : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (op->hasAttr("skip")) // Both ops get "skip" below, so each is rewritten at most once.
+ return failure();
+
+ Operation *newOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ op->setAttr("skip", rewriter.getBoolAttr(true));
+ newOp->setAttr("skip", rewriter.getBoolAttr(true));
+
+ return success();
+ }
+ };
+
+ // Replacing an operation may trigger re-visiting of its users.
+ class ReplaceWithSameOp : public RewritePattern {
+ public:
+ ReplaceWithSameOp(MLIRContext *context)
+ : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ Operation *newOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+ };
+
+ // Erasing an operation may trigger re-visiting of its operands' defining ops.
+ class EraseOp : public RewritePattern {
+ public:
+ EraseOp(MLIRContext *context)
+ : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return success();
+ }
+ };
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1471,6 +1556,7 @@ void registerPatternsTestPass() {
PassRegistration<TestDerivedAttributeDriver>();
PassRegistration<TestPatternDriver>();
+ PassRegistration<TestStrictPatternDriver>();
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
More information about the Mlir-commits
mailing list