[Mlir-commits] [mlir] 2fea658 - [mlir] GreedyPatternRewriter: Reprocess modified ops
Matthias Springer
llvmlistbot at llvm.org
Fri Nov 18 02:44:06 PST 2022
Author: Matthias Springer
Date: 2022-11-18T11:43:44+01:00
New Revision: 2fea658a74196a9c9128be34c5bc306eba7a025e
URL: https://github.com/llvm/llvm-project/commit/2fea658a74196a9c9128be34c5bc306eba7a025e
DIFF: https://github.com/llvm/llvm-project/commit/2fea658a74196a9c9128be34c5bc306eba7a025e.diff
LOG: [mlir] GreedyPatternRewriter: Reprocess modified ops
Ops that were modified in-place (`finalizeRootUpdate` was called) should be reprocessed by the GreedyPatternRewriter. This currently does not happen when `GreedyRewriteConfig::maxIterations = 1`.
Note: If your project goes into an infinite loop because of this change, you likely have one or more faulty patterns that keep modifying the same operations in-place (`updateRootInPlace`) indefinitely.
Differential Revision: https://reviews.llvm.org/D138038
Added:
mlir/test/IR/greedy-pattern-rewriter-driver.mlir
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 386f296e76a66..52f2b83c05d5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -525,7 +525,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
auto inBoundsAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ b.updateRootInPlace(xferOp, [&]() {
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ });
return success();
}
@@ -596,7 +598,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
- xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ b.updateRootInPlace(xferOp, [&]() {
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ });
return success();
}
@@ -623,7 +627,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
else
createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
- xferOp->erase();
+ b.eraseOp(xferOp);
return success();
}
@@ -634,11 +638,5 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
failed(filter(xferOp)))
return failure();
- rewriter.startRootUpdate(xferOp);
- if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
- rewriter.finalizeRootUpdate(xferOp);
- return success();
- }
- rewriter.cancelRootUpdate(xferOp);
- return failure();
+ return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 9c62d61fe5291..935ca2eb93740 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -51,6 +51,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// If the specified operation is in the worklist, remove it.
void removeFromWorklist(Operation *op);
+ /// Notifies the driver that the specified operation may have been modified
+ /// in-place.
+ void finalizeRootUpdate(Operation *op) override;
+
protected:
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
@@ -326,6 +330,14 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
addToWorklist(op);
}
+void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ addToWorklist(op);
+}
+
void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
new file mode 100644
index 0000000000000..4f1a06fa6cf21
--- /dev/null
+++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+
+// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
+func.func @add_to_worklist_after_inplace_update() {
+ // The following op is updated in-place and should be added back to the
+ // worklist of the GreedyPatternRewriteDriver (regardless of the value of
+ // config.max_iterations).
+
+ // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12f374777936c..2d2bf8d71286e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -147,6 +147,26 @@ struct FolderCommutativeOp2WithConstant
}
};
+/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
+/// attribute with value smaller than MaxVal, it increments the value by 1.
+template <int MaxVal>
+struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
+ using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AnyAttrOfOp op,
+ PatternRewriter &rewriter) const override {
+ auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+ if (!intAttr)
+ return failure();
+ int64_t val = intAttr.getInt();
+ if (val >= MaxVal)
+ return failure();
+ rewriter.updateRootInPlace(
+ op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -165,8 +185,12 @@ struct TestPatternDriver
FolderInsertBeforePreviouslyFoldedConstantPattern,
FolderCommutativeOp2WithConstant>(&getContext());
+ // Additional patterns for testing the GreedyPatternRewriteDriver.
+ patterns.insert<IncrementIntAttribute<3>>(&getContext());
+
GreedyRewriteConfig config;
config.useTopDownTraversal = this->useTopDownTraversal;
+ config.maxIterations = this->maxIterations;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
@@ -175,6 +199,10 @@ struct TestPatternDriver
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+ Option<int> maxIterations{
+ *this, "max-iterations",
+ llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
+ llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};
struct TestStrictPatternDriver
More information about the Mlir-commits
mailing list