[Mlir-commits] [mlir] ba3a9f5 - [mlir:MultiOpDriver] Add operands to worklist should be checked

Chia-hung Duan llvmlistbot at llvm.org
Sat Jun 11 09:23:28 PDT 2022


Author: Chia-hung Duan
Date: 2022-06-11T15:56:23Z
New Revision: ba3a9f51ffd903edf97b0cb7d97c073d907fee30

URL: https://github.com/llvm/llvm-project/commit/ba3a9f51ffd903edf97b0cb7d97c073d907fee30
DIFF: https://github.com/llvm/llvm-project/commit/ba3a9f51ffd903edf97b0cb7d97c073d907fee30.diff

LOG: [mlir:MultiOpDriver] Add operands to worklist should be checked

An operand's defining op may not be valid for adding to the worklist under
strict mode.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127180

Added: 
    mlir/test/Transforms/test-strict-pattern-driver.mlir

Modified: 
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7697e603507a4..7305a376449d4 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -43,7 +43,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   bool simplify(MutableArrayRef<Region> regions);
 
   /// Add the given operation to the worklist.
-  void addToWorklist(Operation *op);
+  virtual void addToWorklist(Operation *op);
 
   /// Pop the next operation from the worklist.
   Operation *popFromWorklist();
@@ -60,8 +60,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   // be re-added to the worklist. This function should be called when an
   // operation is modified or removed, as it may trigger further
   // simplifications.
-  template <typename Operands>
-  void addToWorklist(Operands &&operands);
+  void addOperandsToWorklist(ValueRange operands);
 
   // If an operation is about to be removed, make sure it is not in our
   // worklist anymore because we'd get dangling references to it.
@@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
       originalOperands.assign(op->operand_begin(), op->operand_end());
       auto preReplaceAction = [&](Operation *op) {
         // Add the operands to the worklist for visitation.
-        addToWorklist(originalOperands);
+        addOperandsToWorklist(originalOperands);
 
         // Add all the users of the result to the worklist so we make sure
         // to revisit them.
@@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   addToWorklist(op);
 }
 
-template <typename Operands>
-void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
+void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
   for (Value operand : operands) {
     // If the use count of this operand is now < 2, we re-add the defining
     // operation to the worklist.
@@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
 }
 
 void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
-  addToWorklist(op->getOperands());
+  addOperandsToWorklist(op->getOperands());
   op->walk([this](Operation *operation) {
     removeFromWorklist(operation);
     folder.notifyRemoval(operation);
@@ -523,22 +521,12 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 
   bool simplifyLocally(ArrayRef<Operation *> op);
 
-private:
-  // Look over the provided operands for any defining operations that should
-  // be re-added to the worklist. This function should be called when an
-  // operation is modified or removed, as it may trigger further
-  // simplifications. If `strict` is set to true, only ops in
-  // `strictModeFilteredOps` are considered.
-  template <typename Operands>
-  void addOperandsToWorklist(Operands &&operands) {
-    for (Value operand : operands) {
-      if (auto *defOp = operand.getDefiningOp()) {
-        if (!strictMode || strictModeFilteredOps.contains(defOp))
-          addToWorklist(defOp);
-      }
-    }
+  void addToWorklist(Operation *op) override {
+    if (!strictMode || strictModeFilteredOps.contains(op))
+      GreedyPatternRewriteDriver::addToWorklist(op);
   }
 
+private:
   void notifyOperationInserted(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationInserted(op);
     if (strictMode)
@@ -551,15 +539,6 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
       strictModeFilteredOps.erase(op);
   }
 
-  void notifyRootReplaced(Operation *op) override {
-    for (auto result : op->getResults()) {
-      for (auto *user : result.getUsers()) {
-        if (!strictMode || strictModeFilteredOps.contains(user))
-          addToWorklist(user);
-      }
-    }
-  }
-
   /// If `strictMode` is true, any pre-existing ops outside of
   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
   /// If `strictMode` is false, operations that use results of (or supply
@@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
 
       // Add all the users of the result to the worklist so we make sure
       // to revisit them.
-      for (Value result : op->getResults())
-        for (Operation *userOp : result.getUsers()) {
-          if (!strictMode || strictModeFilteredOps.contains(userOp))
-            addToWorklist(userOp);
-        }
+      for (Value result : op->getResults()) {
+        for (Operation *userOp : result.getUsers())
+          addToWorklist(userOp);
+      }
+
       notifyOperationRemoved(op);
     };
 
     // Add the given operation generated by the folder to the worklist.
     auto processGeneratedConstants = [this](Operation *op) {
-      // Newly created ops are also simplified -- these are also "local".
-      addToWorklist(op);
-      // When strict mode is off, we don't need to maintain
-      // strictModeFilteredOps.
-      if (strictMode)
-        strictModeFilteredOps.insert(op);
+      notifyOperationInserted(op);
     };
 
     // Try to fold this op.

diff  --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
new file mode 100644
index 0000000000000..51d296935a97b
--- /dev/null
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
+// Erasing test.erase_op must not revisit test.arg0/test.arg1 (not in op set).
+// CHECK-LABEL: @test_erase
+func.func @test_erase() {
+  %0 = "test.arg0"() : () -> (i32)
+  %1 = "test.arg1"() : () -> (i32)
+  %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
+  return
+}
+// The op created by the test.insert_same_op pattern is new, so it may be visited.
+// CHECK-LABEL: @test_insert_same_op
+func.func @test_insert_same_op() {
+  %0 = "test.insert_same_op"() : () -> (i32)
+  return
+}
+// Replacing the op must not revisit the test.dummy_user ops (not in op set).
+// CHECK-LABEL: @test_replace_with_same_op
+func.func @test_replace_with_same_op() {
+  %0 = "test.replace_with_same_op"() : () -> (i32)
+  %1 = "test.dummy_user"(%0) : (i32) -> (i32)
+  %2 = "test.dummy_user"(%0) : (i32) -> (i32)
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 264b118c8956d..d23f69ded8f99 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -176,6 +176,91 @@ struct TestPatternDriver
       llvm::cl::desc("Seed the worklist in general top-down order"),
       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
 };
+
+struct TestStrictPatternDriver
+    : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
+
+  TestStrictPatternDriver() = default;
+  TestStrictPatternDriver(const TestStrictPatternDriver &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const final { return "test-strict-pattern-driver"; }
+  StringRef getDescription() const final {
+    return "Run strict mode of pattern driver";
+  }
+
+  void runOnOperation() override {
+    mlir::RewritePatternSet patterns(&getContext());
+    patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
+    SmallVector<Operation *> ops; // Initial op set handed to the strict driver.
+    getOperation()->walk([&](Operation *op) {
+      StringRef opName = op->getName().getStringRef();
+      if (opName == "test.insert_same_op" ||
+          opName == "test.replace_with_same_op" || opName == "test.erase_op") {
+        ops.push_back(op);
+      }
+    });
+
+    // Check whether these rewrites make the driver visit operations that are
+    // not in the `ops` set (newly created ops are allowed). Visiting such an
+    // operation is expected to trigger an assertion while processing.
+    (void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
+                                 /*strict=*/true);
+  }
+
+private:
+  // A newly inserted operation is valid for further transformation.
+  class InsertSameOp : public RewritePattern {
+  public:
+    InsertSameOp(MLIRContext *context)
+        : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
+
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->hasAttr("skip")) // Already cloned; prevents endless re-cloning.
+        return failure();
+
+      Operation *newOp =
+          rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+                          op->getOperands(), op->getResultTypes());
+      op->setAttr("skip", rewriter.getBoolAttr(true));
+      newOp->setAttr("skip", rewriter.getBoolAttr(true));
+
+      return success();
+    }
+  };
+
+  // Replacing an operation may cause its users to be revisited.
+  class ReplaceWithSameOp : public RewritePattern {
+  public:
+    ReplaceWithSameOp(MLIRContext *context)
+        : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
+
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      Operation *newOp =
+          rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+                          op->getOperands(), op->getResultTypes());
+      rewriter.replaceOp(op, newOp->getResults());
+      return success();
+    }
+  };
+
+  // Erasing an op may cause its operands' defining ops to be revisited.
+  class EraseOp : public RewritePattern {
+  public:
+    EraseOp(MLIRContext *context)
+        : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      rewriter.eraseOp(op);
+      return success();
+    }
+  };
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1471,6 +1556,7 @@ void registerPatternsTestPass() {
   PassRegistration<TestDerivedAttributeDriver>();
 
   PassRegistration<TestPatternDriver>();
+  PassRegistration<TestStrictPatternDriver>();
 
   PassRegistration<TestLegalizePatternDriver>([] {
     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);


        


More information about the Mlir-commits mailing list