[llvm-branch-commits] [mlir] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion (PR #81240)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 9 02:01:35 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81240
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.
`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)
>From 7503c0cb484c54249ff66c5780197d46937c660d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 9 Feb 2024 09:58:46 +0000
Subject: [PATCH] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect
conversion
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.
`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)
---
mlir/include/mlir/IR/PatternMatch.h | 6 +-
.../mlir/Transforms/DialectConversion.h | 5 --
.../Transforms/Utils/DialectConversion.cpp | 74 +++++++++++++++----
mlir/test/Transforms/test-legalizer.mlir | 14 ++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 20 ++++-
5 files changed, 93 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d2..b8aeea0d23475b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
- virtual void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
- virtual void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefbb..c0c702a7d34821 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
- void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) override;
- void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) override;
-
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 44c107c8733f3d..ffdb069f6e9b81 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -757,7 +757,8 @@ class RewriteAction {
InlineBlock,
MoveBlock,
SplitBlock,
- BlockTypeConversion
+ BlockTypeConversion,
+ MoveOperation
};
virtual ~RewriteAction() = default;
@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction {
void rollback() override;
};
+
+/// Base class for rewrite actions that operate on a single operation.
+class OperationAction : public RewriteAction {
+public:
+ /// Return the operation that this action operates on.
+ Operation *getOperation() const { return op; }
+
+ static bool classof(const RewriteAction *action) {
+ return action->getKind() >= Kind::MoveOperation &&
+ action->getKind() <= Kind::MoveOperation; // only op-action kind so far
+ }
+
+protected: // constructed only through concrete subclasses
+ OperationAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : RewriteAction(kind, rewriterImpl), op(op) {}
+
+ // The operation that this action operates on.
+ Operation *op;
+};
+
+/// Rewrite action that represents the moving of an operation.
+class MoveOperationAction : public OperationAction {
+public:
+ MoveOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, Block *block, Operation *insertBeforeOp)
+ : OperationAction(Kind::MoveOperation, rewriterImpl, op), block(block),
+ insertBeforeOp(insertBeforeOp) {}
+
+ static bool classof(const RewriteAction *action) {
+ return action->getKind() == Kind::MoveOperation;
+ }
+
+ void rollback() override {
+ // Move the operation back to its original position.
+ Block::iterator before = // no original successor: append to `block`
+ insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+ block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+ }
+
+private:
+ // The block in which this operation was previously contained.
+ Block *block;
+
+ // The original successor of this operation before it was moved. "nullptr" if
+ // this operation was the last operation in its original block.
+ Operation *insertBeforeOp;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1468,12 +1517,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- createdOps.push_back(op);
+ if (!previous.isSet()) {
+ // No previous insertion point: this is a newly created op.
+ createdOps.push_back(op);
+ return;
+ }
+ Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+ ? nullptr // `op` was the last op in its original block
+ : &*previous.getPoint(); // despite the name: op that originally followed `op`
+ appendRewriteAction<MoveOperationAction>(op, previous.getBlock(), prevOp); // moved op: record for rollback
}
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1712,18 +1768,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719cede..84fcc18ab7d370 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
return
}
}
+
+// -----
+// Moving "test.hoist_me" out of the region must be rolled back when it fails to legalize.
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+ // CHECK: "test.one_region_op"()
+ // CHECK: "test.hoist_me"()
+ "test.one_region_op"() ({
+ // expected-remark @below{{'test.hoist_me' is not legalizable}}
+ %0 = "test.hoist_me"() : () -> (i32)
+ "test.valid"(%0) : (i32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1fb..1c02232b8adbb1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+/// This pattern hoists "test.hoist_me" ops out of their parent op's region and
+/// then fails conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+ TestUndoMoveOpBefore(MLIRContext *ctx)
+ : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.moveOpBefore(op, op->getParentOp()); // hoist `op` right before its parent op
+ // Replace with an illegal op to ensure the conversion fails.
+ rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+ return success();
+ }
+};
+
/// A rewrite pattern that tests the undo mechanism when erasing a block.
struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
- TestCreateUnregisteredOp>(&getContext());
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp>();
+ TerminatorOp, OneRegionOp>();
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
More information about the llvm-branch-commits
mailing list