[Mlir-commits] [mlir] [mlir][Transforms] Add listener support to dialect conversion (PR #83425)
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 4 15:32:56 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/83425
>From 4ea3fc3f686937d2f3e2f23c34197343aae2c0be Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 4 Mar 2024 23:09:20 +0000
Subject: [PATCH 1/2] [mlir][Transforms][NFC] Make signature conversion more
efficient
During block signature conversion, a new block is inserted and ops are moved from the old block to the new block. This commit changes the implementation such that ops are moved in bulk (`splice`) instead of one-by-one (which is what `splitBlock` does).
This also makes it possible to pass the new block argument types directly to `createBlock` instead of using `addArgument` (which bypasses the rewriter). This doesn't change anything from a technical point of view (there is no rewriter API for adding arguments at the moment), but the implementation reads a bit nicer.
---
.../Transforms/Utils/DialectConversion.cpp | 27 ++++++++++---------
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9dc806730d01a1..3cfa2a66633e1a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1281,7 +1281,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
- MLIRContext *ctx = rewriter.getContext();
+ OpBuilder::InsertionGuard g(rewriter);
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
@@ -1289,14 +1289,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (llvm::equal(block->getArgumentTypes(), convertedTypes))
return block;
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = rewriter.splitBlock(block, block->begin());
- block->replaceAllUsesWith(newBlock);
-
- // Map all new arguments to the location of the argument they originate from.
+ // Compute the locations of all block arguments in the new block.
SmallVector<Location> newLocs(convertedTypes.size(),
- Builder(ctx).getUnknownLoc());
+ rewriter.getUnknownLoc());
for (unsigned i = 0; i < origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap || inputMap->replacementValue)
@@ -1306,9 +1301,16 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newLocs[inputMap->inputNo + j] = origLoc;
}
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
+ // Insert a new block with the converted block argument types and move all ops
+ // from the old block to the new block.
+ Block *newBlock =
+ rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
+ convertedTypes, newLocs);
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+
+ // Replace all uses of the old block with the new block.
+ block->replaceAllUsesWith(newBlock);
// Remap each of the original arguments as determined by the signature
// conversion.
@@ -1333,7 +1335,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
// Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ auto replArgs =
+ newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value newArg;
// If this is a 1->1 mapping and the types of new and replacement arguments
>From 9b4aa86168227de6a0bd2739b62922dd8f931afb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 4 Mar 2024 23:28:46 +0000
Subject: [PATCH 2/2] [mlir][Transforms] Add listener support to dialect
conversion
---
.../mlir/Transforms/DialectConversion.h | 33 +++
.../Transforms/Utils/DialectConversion.cpp | 225 ++++++++++++++----
mlir/test/Transforms/test-legalizer.mlir | 71 +++++-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 28 ++-
4 files changed, 302 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..b92357ef2046d7 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1091,6 +1091,39 @@ struct ConversionConfig {
/// IR during an analysis conversion and only pre-existing operations are
/// added to the set.
DenseSet<Operation *> *legalizableOps = nullptr;
+
+ /// An optional listener that is notified about all IR modifications in case
+ /// dialect conversion succeeds. If the dialect conversion fails and no IR
+ /// modifications are visible (i.e., they were all rolled back), no
+ /// notifications are sent.
+ ///
+ /// Note: Notifications are sent in a delayed fashion, when the dialect
+ /// conversion is guaranteed to succeed. At that point, some IR modifications
+ /// may already have been materialized. Consequently, operations/blocks that
+ /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+ /// guaranteed to be valid pointers and accessing op names is allowed. But
+ /// there are no guarantees about the state of ops/blocks at the time that a
+ /// callback is triggered.)
+ ///
+ /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+ /// created and inserted, and later moved to another block. (Moving ops also
+ /// triggers "notifyOperationInserted".)
+ ///
+ /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+ /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+ ///
+ /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+ /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+ /// immediately performed by the dialect conversion (and rolled back upon
+ /// failure).
+ //
+ // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
+ // callback, the previous region/block is provided to the callback, but not
+ // the iterator pointing to the exact location within the region/block. That
+ // is because these notifications are sent with a delay (after the IR has
+ // already been modified) and iterators into past IR state cannot be
+ // represented at the moment.
+ RewriterBase::Listener *listener = nullptr;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfa2a66633e1a..a5145246bc30c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -204,14 +204,22 @@ class IRRewrite {
/// Roll back the rewrite. Operations may be erased during rollback.
virtual void rollback() = 0;
- /// Commit the rewrite. Operations/blocks may be unlinked during the commit
- /// phase, but they must not be erased yet. This is because internal dialect
- /// conversion state (such as `mapping`) may still be using them. Operations/
- /// blocks must be erased during cleanup.
- virtual void commit() {}
+ /// Commit the rewrite. At this point, it is certain that the dialect
+ /// conversion will succeed. All IR modifications, except for operation/block
+ /// erasure, must be performed through the given rewriter.
+ ///
+ /// Instead of erasing operations/blocks, they should merely be unlinked
+ /// during the commit phase and finally be erased during the cleanup
+ /// phase. This is because internal dialect conversion state (such as
+ /// `mapping`) may still be using them.
+ ///
+ /// Any IR modification that was already performed before the commit phase
+ /// (e.g., insertion of an op) must be communicated to the listener that may
+ /// be attached to the given rewriter.
+ virtual void commit(RewriterBase &rewriter) {}
/// Cleanup operations/blocks. Cleanup is called after commit.
- virtual void cleanup() {}
+ virtual void cleanup(RewriterBase &rewriter) {}
Kind getKind() const { return kind; }
@@ -221,12 +229,6 @@ class IRRewrite {
IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
: kind(kind), rewriterImpl(rewriterImpl) {}
- /// Erase the given op (unless it was already erased).
- void eraseOp(Operation *op);
-
- /// Erase the given block (unless it was already erased).
- void eraseBlock(Block *block);
-
const ConversionConfig &getConfig() const;
const Kind kind;
@@ -265,6 +267,12 @@ class CreateBlockRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::CreateBlock;
}
+ void commit(RewriterBase &rewriter) override {
+ // The block was already created and inserted. Just inform the listener.
+ if (auto *listener = rewriter.getListener())
+ listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+ }
+
void rollback() override {
// Unlink all of the operations within this block, they will be deleted
// separately.
@@ -311,10 +319,19 @@ class EraseBlockRewrite : public BlockRewrite {
block = nullptr;
}
- void cleanup() override {
+ void commit(RewriterBase &rewriter) override {
// Erase the block.
assert(block && "expected block");
assert(block->empty() && "expected empty block");
+
+ // Notify the listener that the block is about to be erased.
+ if (auto *listener =
+ dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+ listener->notifyBlockErased(block);
+ }
+
+ void cleanup(RewriterBase &rewriter) {
+ // Erase the block.
block->dropAllDefinedValueUses();
delete block;
block = nullptr;
@@ -341,6 +358,13 @@ class InlineBlockRewrite : public BlockRewrite {
firstInlinedInst(sourceBlock->empty() ? nullptr
: &sourceBlock->front()),
lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+ // If a listener is attached to the dialect conversion, ops must be moved
+ // one-by-one. When they are moved in bulk, notifications cannot be sent
+ // because the ops that used to be in the source block at the time of the
+ // inlining (before the "commit" phase) are unknown at the time when
+ // notifications are sent (which is during the "commit" phase).
+ assert(!getConfig().listener &&
+ "InlineBlockRewrite not supported if listener is attached");
}
static bool classof(const IRRewrite *rewrite) {
@@ -382,6 +406,16 @@ class MoveBlockRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::MoveBlock;
}
+ void commit(RewriterBase &rewriter) override {
+ // The block was already moved. Just inform the listener.
+ if (auto *listener = rewriter.getListener()) {
+ // Note: `previousIt` cannot be passed because this is a delayed
+ // notification and iterators into past IR state cannot be represented.
+ listener->notifyBlockInserted(block, /*previous=*/region,
+ /*previousIt=*/{});
+ }
+ }
+
void rollback() override {
// Move the block back to its original position.
Region::iterator before =
@@ -437,7 +471,7 @@ class BlockTypeConversionRewrite : public BlockRewrite {
LogicalResult
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -466,7 +500,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::ReplaceBlockArg;
}
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -506,6 +540,17 @@ class MoveOperationRewrite : public OperationRewrite {
return rewrite->getKind() == Kind::MoveOperation;
}
+ void commit(RewriterBase &rewriter) override {
+ // The operation was already moved. Just inform the listener.
+ if (auto *listener = rewriter.getListener()) {
+ // Note: `previousIt` cannot be passed because this is a delayed
+ // notification and iterators into past IR state cannot be represented.
+ listener->notifyOperationInserted(
+ op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+ /*insertPt=*/{}));
+ }
+ }
+
void rollback() override {
// Move the operation back to its original position.
Block::iterator before =
@@ -549,7 +594,12 @@ class ModifyOperationRewrite : public OperationRewrite {
"rewrite was neither committed nor rolled back");
}
- void commit() override {
+ void commit(RewriterBase &rewriter) override {
+ // Notify the listener that the operation was modified in-place.
+ if (auto *listener =
+ dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+ listener->notifyOperationModified(op);
+
if (propertiesStorage) {
OpaqueProperties propCopy(propertiesStorage);
// Note: The operation may have been erased in the mean time, so
@@ -600,11 +650,11 @@ class ReplaceOperationRewrite : public OperationRewrite {
return rewrite->getKind() == Kind::ReplaceOperation;
}
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
- void cleanup() override;
+ void cleanup(RewriterBase &rewriter) override;
const TypeConverter *getConverter() const { return converter; }
@@ -629,6 +679,12 @@ class CreateOperationRewrite : public OperationRewrite {
return rewrite->getKind() == Kind::CreateOperation;
}
+ void commit(RewriterBase &rewriter) override {
+ // The operation was already created and inserted. Just inform the listener.
+ if (auto *listener = rewriter.getListener())
+ listener->notifyOperationInserted(op, /*previous=*/{});
+ }
+
void rollback() override;
};
@@ -666,7 +722,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
void rollback() override;
- void cleanup() override;
+ void cleanup(RewriterBase &rewriter) override;
/// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
@@ -735,7 +791,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : eraseRewriter(ctx), config(config) {}
+ : context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -900,6 +956,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
}
void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
void notifyBlockErased(Block *block) override { erased.insert(block); }
/// Pointers to all erased operations and blocks.
@@ -910,8 +967,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// This rewriter must be used for erasing ops/blocks.
- SingleEraseRewriter eraseRewriter;
+ /// MLIR context.
+ MLIRContext *context;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -955,19 +1012,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
} // namespace detail
} // namespace mlir
-void IRRewrite::eraseOp(Operation *op) {
- rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
- rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
const ConversionConfig &IRRewrite::getConfig() const {
return rewriterImpl.config;
}
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+ // Inform the listener about all IR modifications that have already taken
+ // place: References to the original block have been replaced with the new
+ // block, so every user of the new block was modified in place.
+ if (auto *listener = dyn_cast_or_null<RewriterBase::Listener>(
+ rewriter.getListener()))
+ for (Operation *op : block->getUsers())
+ listener->notifyOperationModified(op);
+
// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -975,7 +1032,7 @@ void BlockTypeConversionRewrite::commit() {
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
+ rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}
@@ -985,8 +1042,8 @@ void BlockTypeConversionRewrite::commit() {
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+ rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+ castValue, origArg.getType()));
}
}
}
@@ -1042,13 +1099,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
return;
if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
+ rewriter.replaceAllUsesWith(arg, repl);
return;
}
@@ -1057,7 +1114,7 @@ void ReplaceBlockArgRewrite::commit() {
// replacement value.
Operation *replOp = cast<OpResult>(repl).getOwner();
Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
@@ -1065,14 +1122,40 @@ void ReplaceBlockArgRewrite::commit() {
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
-void ReplaceOperationRewrite::commit() {
- for (OpResult result : op->getResults())
- if (Value newValue =
- rewriterImpl.mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+ auto *listener =
+ dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
+
+ // Compute replacement values (null for results that have no mapping).
+ SmallVector<Value> replacements =
+ llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+ return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+ });
+
+ // Notify the listener that the operation is about to be replaced.
+ if (listener)
+ listener->notifyOperationReplaced(op, replacements);
+
+ // Replace all uses with the new values.
+ for (auto [result, newValue] :
+ llvm::zip_equal(op->getResults(), replacements))
+ if (newValue)
+ rewriter.replaceAllUsesWith(result, newValue);
+
+ // The original op will be erased, so remove it from the set of unlegalized
+ // ops.
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
+
+ // Notify the listener that the operation (and its nested operations) was
+ // erased.
+ if (listener) {
+ op->walk<WalkOrder::PostOrder>(
+ [&](Operation *nestedOp) { listener->notifyOperationErased(nestedOp); });
+ }
+
// Do not erase the operation yet. It may still be referenced in `mapping`.
+ // Just unlink it for now and erase it during cleanup.
op->getBlock()->getOperations().remove(op);
}
@@ -1081,7 +1164,9 @@ void ReplaceOperationRewrite::rollback() {
rewriterImpl.mapping.erase(result);
}
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+ rewriter.eraseOp(op);
+}
void CreateOperationRewrite::rollback() {
for (Region ®ion : op->getRegions()) {
@@ -1100,14 +1185,20 @@ void UnresolvedMaterializationRewrite::rollback() {
op->erase();
}
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+ rewriter.eraseOp(op);
+}
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
+ IRRewriter rewriter(context, config.listener);
for (auto &rewrite : rewrites)
- rewrite->commit();
+ rewrite->commit(rewriter);
+
+ // Clean up all rewrites.
+ SingleEraseRewriter eraseRewriter(context);
for (auto &rewrite : rewrites)
- rewrite->cleanup();
+ rewrite->cleanup(eraseRewriter);
}
//===----------------------------------------------------------------------===//
@@ -1306,8 +1397,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
Block *newBlock =
rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
convertedTypes, newLocs);
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
- newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+
+ // If a listener is attached to the dialect conversion, ops cannot be moved
+ // to the destination block in bulk ("fast path"). This is because at the time
+ // the notifications are sent, it is unknown which ops were moved. Instead,
+ // ops should be moved one-by-one ("slow path"), so that a separate
+ // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+ // a bit more efficient, so we try to do that when possible.
+ bool fastPath = !config.listener;
+ if (fastPath) {
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+ } else {
+ while (!block->empty())
+ rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
+ }
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);
@@ -1652,10 +1756,31 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
"expected 'source' to have no predecessors");
#endif // NDEBUG
- impl->notifyBlockBeingInlined(dest, source, before);
+ // If a listener is attached to the dialect conversion, ops cannot be moved
+ // to the destination block in bulk ("fast path"). This is because at the time
+ // the notifications are sent, it is unknown which ops were moved. Instead,
+ // ops should be moved one-by-one ("slow path"), so that a separate
+ // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+ // a bit more efficient, so we try to do that when possible.
+ bool fastPath = !impl->config.listener;
+
+ if (fastPath)
+ impl->notifyBlockBeingInlined(dest, source, before);
+
+ // Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
- dest->getOperations().splice(before, source->getOperations());
+
+ if (fastPath) {
+ // Move all ops at once.
+ dest->getOperations().splice(before, source->getOperations());
+ } else {
+ // Move op by op.
+ while (!source->empty())
+ moveOpBefore(&source->front(), dest, before);
+ }
+
+ // Erase the source block.
eraseBlock(source);
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ccdc9fe78ea0d3..d552f0346644b3 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
+// CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
// CHECK-LABEL: verifyDirectPattern
func.func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
return %result : i32
}
+// -----
+
+// CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
// CHECK-LABEL: verifyLargerBenefit
func.func @verifyLargerBenefit() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
return %result : i32
}
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
// CHECK-LABEL: func private @remap_input_1_to_0()
func.func private @remap_input_1_to_0(i16)
+// -----
+
// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
func.func @remap_input_1_to_1(%arg0: i64) {
// CHECK-NEXT: "test.valid"{{.*}} : (f64)
"test.invalid"(%arg0) : (i64) -> ()
}
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
func.func @remap_call_1_to_1(%arg0: i64) {
// CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
return
}
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK: notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
// CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
// CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
"test.return"(%arg0) : (f32) -> ()
}
+// -----
+
// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
// CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
"test.return"(%arg0) : (i42) -> ()
}
+// -----
+
// CHECK-LABEL: func @remap_input_to_self
func.func @remap_input_to_self(%arg0: index) {
// CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
"test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
}
+// -----
+
// CHECK-LABEL: func @no_remap_nested
func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
return
}
+// -----
+
// CHECK-LABEL: func @remap_moved_region_args
func.func @remap_moved_region_args() {
// CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
return
}
+// -----
+
// CHECK-LABEL: func @remap_cloned_region_args
func.func @remap_cloned_region_args() {
// CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
return
}
+// -----
+
// CHECK-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
"work"(%arg) : (i16) -> ()
}
+// -----
+
// CHECK-LABEL: func @up_to_date_replacement
func.func @up_to_date_replacement(%arg: i8) -> i8 {
// CHECK-NEXT: return
@@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 {
return %repl_2 : i8
}
+// -----
+
// CHECK-LABEL: func @remove_foldable_op
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32)
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
@@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
return %0 : i32
}
+// -----
+
// CHECK-LABEL: @create_block
func.func @create_block() {
// Check that we created a block with arguments.
@@ -161,6 +224,12 @@ func.func @create_block() {
return
}
+// -----
+
+// CHECK: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+
// CHECK-LABEL: @bounded_recursion
func.func @bounded_recursion() {
// CHECK: test.recursive_rewrite 0
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..2628a2784c18da 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -327,8 +327,12 @@ struct TestPatternDriver
struct DumpNotifications : public RewriterBase::Listener {
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
- llvm::outs() << "notifyBlockInserted into "
- << block->getParentOp()->getName() << ": ";
+ llvm::outs() << "notifyBlockInserted";
+ if (block->getParentOp()) {
+ llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
+ } else {
+ llvm::outs() << " into unknown op: ";
+ }
if (previous == nullptr) {
llvm::outs() << "was unlinked\n";
} else {
@@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener {
if (!previous.isSet()) {
llvm::outs() << ", was unlinked\n";
} else {
- if (previous.getPoint() == previous.getBlock()->end()) {
+ if (!previous.getPoint().getNodePtr()) {
+ llvm::outs() << ", was linked, exact position unknown\n";
+ } else if (previous.getPoint() == previous.getBlock()->end()) {
llvm::outs() << ", was last in block\n";
} else {
llvm::outs() << ", previous = " << previous.getPoint()->getName()
@@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener {
}
}
}
+ void notifyBlockErased(Block *block) override {
+ llvm::outs() << "notifyBlockErased\n";
+ }
void notifyOperationErased(Operation *op) override {
llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
}
+ void notifyOperationModified(Operation *op) override {
+ llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
+ }
+ void notifyOperationReplaced(Operation *op, ValueRange values) override {
+ llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
+ }
};
struct TestStrictPatternDriver
@@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
+ DumpNotifications dumpNotifications;
+ config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
@@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
+ ConversionConfig config;
+ DumpNotifications dumpNotifications;
+ config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
- std::move(patterns)))) {
+ std::move(patterns), config))) {
getOperation()->emitRemark() << "applyFullConversion failed";
}
return;
More information about the Mlir-commits
mailing list