[Mlir-commits] [mlir] [mlir][Transforms][NFC] Remove `SplitBlockRewrite` (PR #82777)
Matthias Springer
llvmlistbot at llvm.org
Tue Feb 27 06:11:32 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82777
>From 3216b262a9d95ab33ecdd59215b7d2c3100273fc Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 27 Feb 2024 13:39:32 +0000
Subject: [PATCH 1/2] [mlir][Transforms] Keep track of nested ignored/replaced
ops
The dialect conversion maintains sets of "ignored" and "replaced" ops. This change simplifies the bookkeeping of these two sets such that they always include all nested ops. (Previously, this was not the case: sometimes only the parent op was included.)
This change allows for more aggressive assertions that detect incorrect rewriter API usage, e.g., accessing ops/blocks/regions nested within an erased op.
A concrete example: I have seen conversion patterns in downstream projects where an op is replaced with a new op, and the region of the old op is afterwards inlined into the newly created op. This is invalid rewriter API usage: ops that were replaced/erased must not be accessed. Such nested ops are considered "ignored" (i.e., they are no longer legalized), even if they are moved to a different region after their parent op was erased (which is itself illegal API usage). Instead, create the new op first, inline the regions into it, and only then replace the old op with the new op.
BEGIN_PUBLIC
No commit message needed for presubmit.
END_PUBLIC
---
.../Transforms/Utils/DialectConversion.cpp | 90 +++++++++++--------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 -
2 files changed, 52 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f967e8352bf4c8..8c26c0b60f8a02 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
PatternRewriter &rewriter, ValueRange values,
SmallVectorImpl<Value> &remapped);
- /// Returns true if the given operation is ignored, and does not need to be
+ /// Return "true" if the given operation is ignored, and does not need to be
/// converted.
bool isOpIgnored(Operation *op) const;
- /// Recursively marks the nested operations under 'op' as ignored. This
- /// removes them from being considered for legalization.
- void markNestedOpsIgnored(Operation *op);
+ /// Return "true" if the given operation was replaced or erased.
+ bool wasOpReplaced(Operation *op) const;
//===--------------------------------------------------------------------===//
// Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
- /// A set of operations that should no longer be considered for legalization,
- /// but were not directly replace/erased/etc. by a pattern. These are
- /// generally child operations of other operations who were
- /// replaced/erased/etc. This is not meant to be an exhaustive list of all
- /// operations, but the minimal set that can be used to detect if a given
- /// operation should be `ignored`. For example, we may add the operations that
- /// define non-empty regions to the set, but not any of the others. This
- /// simplifies the amount of memory needed as we can query if the parent
- /// operation was ignored.
+ /// A set of operations that should no longer be considered for legalization.
+ /// E.g., ops that are recursively legal. Ops that were replaced/erased are
+ /// tracked separately.
SetVector<Operation *> ignoredOps;
- // A set of operations that were erased.
+ /// A set of operations that were replaced/erased. Such ops are not erased
+ /// immediately but only when the dialect conversion succeeds. In the mean
+ /// time, they should no longer be considered for legalization and any attempt
+ /// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
/// The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return success();
}
-// TODO: This function is a misnomer. It does not actually check if `op` is in
-// `ignoredOps`.
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
- // Check to see if this operation or the parent operation is ignored.
- return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
+ // Check to see if this operation is ignored or was replaced.
+ return replacedOps.count(op) || ignoredOps.count(op);
}
-void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
- // Walk this operation and collect nested operations that define non-empty
- // regions. We mark such operations as 'ignored' so that we know we don't have
- // to convert them, or their nested ops.
- if (op->getNumRegions() == 0)
- return;
- op->walk([&](Operation *op) {
- if (llvm::any_of(op->getRegions(),
- [](Region ®ion) { return !region.empty(); }))
- ignoredOps.insert(op);
- });
+bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
+ // Check to see if this operation was replaced.
+ return replacedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
+ assert(!wasOpReplaced(op) &&
+ "attempting to insert into a block within a replaced/erased op");
+
if (!previous.isSet()) {
// This is a newly created op.
appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!replacedOps.contains(op) && "operation was already replaced");
+ assert(!ignoredOps.contains(op) && "operation was already replaced");
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
resultChanged);
- // Mark this operation as recursively ignored so that we don't need to
- // convert any nested operations.
- replacedOps.insert(op);
- markNestedOpsIgnored(op);
+ // Mark this operation and all nested ops as replaced.
+ op->walk([&](Operation *op) { replacedOps.insert(op); });
}
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1604,6 +1591,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
+ assert(!impl->wasOpReplaced(block->getParentOp()) &&
+ "attempting to erase a block within a replaced/erased op");
+
// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
@@ -1619,18 +1609,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
+ assert(!impl->wasOpReplaced(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within a "
+ "replaced/erased op");
return impl->applySignatureConversion(region, conversion, converter);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
+ assert(!impl->wasOpReplaced(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within a "
+ "replaced/erased op");
return impl->convertRegionTypes(region, converter, entryConversion);
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
+ assert(!impl->wasOpReplaced(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within a "
+ "replaced/erased op");
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
}
@@ -1665,6 +1664,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
+ assert(!impl->wasOpReplaced(block->getParentOp()) &&
+ "attempting to split a block within a replaced/erased op");
auto *continuation = block->splitBlock(before);
impl->notifySplitBlock(block, continuation);
return continuation;
@@ -1673,15 +1674,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
+#ifndef NDEBUG
assert(argValues.size() == source->getNumArguments() &&
"incorrect # of argument replacement values");
-#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(source->getParentOp()) &&
+ "attempting to inline a block from a replaced/erased op");
+ assert(!impl->wasOpReplaced(dest->getParentOp()) &&
+ "attempting to inline a block into a replaced/erased op");
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
-#endif // NDEBUG
// The source block will be deleted, so it should not have any users (i.e.,
// there should be no predecessors).
assert(llvm::all_of(source->getUsers(), opIgnored) &&
"expected 'source' to have no predecessors");
+#endif // NDEBUG
impl->notifyBlockBeingInlined(dest, source, before);
for (auto it : llvm::zip(source->getArguments(), argValues))
@@ -1691,6 +1696,8 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
@@ -1698,6 +1705,8 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
PatternRewriter::finalizeOpModification(op);
// There is nothing to do here, we only need to track the operation at the
// start of the update.
@@ -1912,8 +1921,13 @@ OperationLegalizer::legalize(Operation *op,
// If this operation is recursively legal, mark its children as ignored so
// that we don't consider them for legalization.
- if (legalityInfo->isRecursivelyLegal)
- rewriter.getImpl().markNestedOpsIgnored(op);
+ if (legalityInfo->isRecursivelyLegal) {
+ op->walk([&](Operation *nested) {
+ if (op != nested)
+ rewriter.getImpl().ignoredOps.insert(nested);
+ });
+ }
+
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bde4255ee4b368..abc0e43c7b7f2d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps
rewriter.inlineBlockBefore(&innerBlock, op);
rewriter.eraseOp(innerTerminator);
rewriter.eraseOp(op);
- rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
>From b85493382a41792ddba2654b7d789a55a564d1c0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 27 Feb 2024 14:10:07 +0000
Subject: [PATCH 2/2] [mlir][Transforms][NFC] Remove `SplitBlockRewrite`
When splitting a block during a dialect conversion, a `SplitBlockRewrite` object is stored in the dialect conversion state. This commit removes `SplitBlockRewrite`. Instead, a block split is now recorded as a combination of one `CreateBlockRewrite` and multiple `MoveOperationRewrite`s.
This change simplifies the internal state of the dialect conversion and is also needed to properly support listeners.
`RewriterBase::splitBlock` is no longer virtual. All necessary information for committing/rolling back a split block rewrite can be deduced from `Listener::notifyBlockInserted` and `Listener::notifyOperationInserted` (which is also called when moving an operation).
---
mlir/include/mlir/IR/PatternMatch.h | 2 +-
.../mlir/Transforms/DialectConversion.h | 3 --
.../Transforms/Utils/DialectConversion.cpp | 42 -------------------
3 files changed, 1 insertion(+), 46 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e783..f8d22cfb22afd0 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -579,7 +579,7 @@ class RewriterBase : public OpBuilder {
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
- virtual Block *splitBlock(Block *block, Block::iterator before);
+ Block *splitBlock(Block *block, Block::iterator before);
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7e8e67a9d17824..84396529eb7c2e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -741,9 +741,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// implemented for dialect conversion.
void eraseBlock(Block *block) override;
- /// PatternRewriter hook for splitting a block into two parts.
- Block *splitBlock(Block *block, Block::iterator before) override;
-
/// PatternRewriter hook for inlining the ops of a block into another block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
ValueRange argValues = std::nullopt) override;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8c26c0b60f8a02..f6e2c10dadf8d7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -192,7 +192,6 @@ class IRRewrite {
EraseBlock,
InlineBlock,
MoveBlock,
- SplitBlock,
BlockTypeConversion,
ReplaceBlockArg,
// Operation rewrites
@@ -400,30 +399,6 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};
-/// Splitting of a block. This rewrite is immediately reflected in the IR.
-class SplitBlockRewrite : public BlockRewrite {
-public:
- SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Block *originalBlock)
- : BlockRewrite(Kind::SplitBlock, rewriterImpl, block),
- originalBlock(originalBlock) {}
-
- static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::SplitBlock;
- }
-
- void rollback() override {
- // Merge back the block that was split out.
- originalBlock->getOperations().splice(originalBlock->end(),
- block->getOperations());
- eraseBlock(block);
- }
-
-private:
- // The original block from which this block was split.
- Block *originalBlock;
-};
-
/// This structure contains the information pertaining to an argument that has
/// been converted.
struct ConvertedArgInfo {
@@ -883,9 +858,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;
- /// Notifies that a block was split.
- void notifySplitBlock(Block *block, Block *continuation);
-
/// Notifies that a block is being inlined into another block.
void notifyBlockBeingInlined(Block *block, Block *srcBlock,
Block::iterator before);
@@ -1519,11 +1491,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
}
-void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
- Block *continuation) {
- appendRewrite<SplitBlockRewrite>(continuation, block);
-}
-
void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
Block *block, Block *srcBlock, Block::iterator before) {
appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
@@ -1662,15 +1629,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}
-Block *ConversionPatternRewriter::splitBlock(Block *block,
- Block::iterator before) {
- assert(!impl->wasOpReplaced(block->getParentOp()) &&
- "attempting to split a block within a replaced/erased op");
- auto *continuation = block->splitBlock(before);
- impl->notifySplitBlock(block, continuation);
- return continuation;
-}
-
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
More information about the Mlir-commits
mailing list