[Mlir-commits] [mlir] aaf5c81 - [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (#83286)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 4 01:06:04 PST 2024
Author: Matthias Springer
Date: 2024-03-04T18:06:00+09:00
New Revision: aaf5c818b3b84a49f0c14879c141c7770f0fcf4b
URL: https://github.com/llvm/llvm-project/commit/aaf5c818b3b84a49f0c14879c141c7770f0fcf4b
DIFF: https://github.com/llvm/llvm-project/commit/aaf5c818b3b84a49f0c14879c141c7770f0fcf4b.diff
LOG: [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (#83286)
When a block signature is converted during dialect conversion, a
`BlockTypeConversionRewrite` object is stored in the stack of rewrites.
Such an object represents multiple steps:
- Splitting the old block, i.e., creating a new block and moving all
operations over.
- Rewriting block arguments.
- Erasing the old block.
We have dedicated `IRRewrite` objects that represent "creating a block",
"moving an op" and "erasing a block". This commit reuses those rewrite
objects, so that there is less work to do in
`BlockTypeConversionRewrite::rollback` and
`BlockTypeConversionRewrite::commit`/`cleanup`.
Note: This change is in preparation of adding listener support to the
dialect conversion. The less work is done in a `commit` function, the
fewer notifications will have to be sent.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7846f1ab56811a..71aee4447f7921 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -441,8 +441,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
void commit() override;
- void cleanup() override;
-
void rollback() override;
private:
@@ -791,24 +789,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
- applySignatureConversion(Region *region,
+ applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(Region *region, const TypeConverter &converter,
+ convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -818,7 +819,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
//===--------------------------------------------------------------------===//
@@ -991,24 +993,8 @@ void BlockTypeConversionRewrite::commit() {
}
}
-void BlockTypeConversionRewrite::cleanup() {
- assert(origBlock->empty() && "expected empty block");
- origBlock->dropAllDefinedValueUses();
- delete origBlock;
- origBlock = nullptr;
-}
-
void BlockTypeConversionRewrite::rollback() {
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- eraseBlock(block);
}
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1224,10 +1210,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Type Conversion
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
- return applySignatureConversion(block, converter, *conversion);
+ return applySignatureConversion(rewriter, block, converter, *conversion);
// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
@@ -1236,35 +1223,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion);
+ return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- Region *region, TypeConverter::SignatureConversion &conversion,
+ ConversionPatternRewriter &rewriter, Region *region,
+ TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
- return *convertBlockSignature(&region->front(), converter, &conversion);
+ return *convertBlockSignature(rewriter, &region->front(), converter,
+ &conversion);
return nullptr;
}
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
return nullptr;
- if (failed(convertNonEntryRegionTypes(region, converter)))
+ if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();
- FailureOr<Block *> newEntry =
- convertBlockSignature(&region->front(), &converter, entryConversion);
+ FailureOr<Block *> newEntry = convertBlockSignature(
+ rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1285,16 +1276,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);
- if (failed(convertBlockSignature(&block, &converter, blockConversion)))
+ if (failed(convertBlockSignature(rewriter, &block, &converter,
+ blockConversion)))
return failure();
}
return success();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
- MLIRContext *ctx = eraseRewriter.getContext();
+ MLIRContext *ctx = rewriter.getContext();
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
@@ -1304,11 +1297,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Split the block at the beginning to get a new block to use for the updated
// signature.
- Block *newBlock = block->splitBlock(block->begin());
+ Block *newBlock = rewriter.splitBlock(block, block->begin());
block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1384,6 +1374,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);
+
+ // Erase the old block. (It is just unlinked for now and will be erased during
+ // cleanup.)
+ rewriter.eraseBlock(block);
+
return newBlock;
}
@@ -1592,7 +1587,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(region, conversion, converter);
+ return impl->applySignatureConversion(*this, region, conversion, converter);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1601,7 +1596,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(region, converter, entryConversion);
+ return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1610,7 +1605,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
+ return impl->convertNonEntryRegionTypes(*this, region, converter,
+ blockConversions);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2104,7 +2100,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- if (failed(impl.convertBlockSignature(block, converter))) {
+ if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
More information about the Mlir-commits
mailing list