[Mlir-commits] [mlir] 2fea658 - [mlir] GreedyPatternRewriter: Reprocess modified ops

Matthias Springer llvmlistbot at llvm.org
Fri Nov 18 02:44:06 PST 2022


Author: Matthias Springer
Date: 2022-11-18T11:43:44+01:00
New Revision: 2fea658a74196a9c9128be34c5bc306eba7a025e

URL: https://github.com/llvm/llvm-project/commit/2fea658a74196a9c9128be34c5bc306eba7a025e
DIFF: https://github.com/llvm/llvm-project/commit/2fea658a74196a9c9128be34c5bc306eba7a025e.diff

LOG: [mlir] GreedyPatternRewriter: Reprocess modified ops

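Ops that were modified in-place (`finalizeRootUpdate` was called) should be reprocessed by the GreedyPatternRewriter. Currently this does not happen when `GreedyRewriteConfig::maxIterations = 1`. The driver does not add such ops back to its worklist, so they can only be revisited in a later iteration, and with a single iteration there is none.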

Note: If your project goes into an infinite loop because of this change, you likely have one or more faulty patterns that keep modifying the same operations in-place (`updateRootInPlace`) indefinitely.

Differential Revision: https://reviews.llvm.org/D138038

Added: 
    mlir/test/IR/greedy-pattern-rewriter-driver.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 386f296e76a66..52f2b83c05d5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -525,7 +525,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
   auto inBoundsAttr = b.getBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    b.updateRootInPlace(xferOp, [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
     return success();
   }
 
@@ -596,7 +598,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
       xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
 
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    b.updateRootInPlace(xferOp, [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
 
     return success();
   }
@@ -623,7 +627,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   else
     createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
 
-  xferOp->erase();
+  b.eraseOp(xferOp);
 
   return success();
 }
@@ -634,11 +638,5 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
       failed(filter(xferOp)))
     return failure();
-  rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
-    rewriter.finalizeRootUpdate(xferOp);
-    return success();
-  }
-  rewriter.cancelRootUpdate(xferOp);
-  return failure();
+  return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 9c62d61fe5291..935ca2eb93740 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -51,6 +51,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// If the specified operation is in the worklist, remove it.
   void removeFromWorklist(Operation *op);
 
+  /// Notifies the driver that the specified operation may have been modified
+  /// in-place.
+  void finalizeRootUpdate(Operation *op) override;
+
 protected:
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -326,6 +330,14 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   addToWorklist(op);
 }
 
+void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  addToWorklist(op);
+}
+
 void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
   for (Value operand : operands) {
     // If the use count of this operand is now < 2, we re-add the defining

diff  --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
new file mode 100644
index 0000000000000..4f1a06fa6cf21
--- /dev/null
+++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+
+// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
+func.func @add_to_worklist_after_inplace_update() {
+  // The following op is updated in-place and should be added back to the
+  // worklist of the GreedyPatternRewriteDriver (regardless of the value of
+  // config.max_iterations).
+
+  // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
+  "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12f374777936c..2d2bf8d71286e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -147,6 +147,26 @@ struct FolderCommutativeOp2WithConstant
   }
 };
 
+/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
+/// attribute with value smaller than MaxVal, it increments the value by 1.
+template <int MaxVal>
+struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
+  using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AnyAttrOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+    if (!intAttr)
+      return failure();
+    int64_t val = intAttr.getInt();
+    if (val >= MaxVal)
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -165,8 +185,12 @@ struct TestPatternDriver
                  FolderInsertBeforePreviouslyFoldedConstantPattern,
                  FolderCommutativeOp2WithConstant>(&getContext());
 
+    // Additional patterns for testing the GreedyPatternRewriteDriver.
+    patterns.insert<IncrementIntAttribute<3>>(&getContext());
+
     GreedyRewriteConfig config;
     config.useTopDownTraversal = this->useTopDownTraversal;
+    config.maxIterations = this->maxIterations;
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                        config);
   }
@@ -175,6 +199,10 @@ struct TestPatternDriver
       *this, "top-down",
       llvm::cl::desc("Seed the worklist in general top-down order"),
       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
+      llvm::cl::init(GreedyRewriteConfig().maxIterations)};
 };
 
 struct TestStrictPatternDriver


        


More information about the Mlir-commits mailing list