[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (PR #83286)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Feb 28 08:35:27 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
When a block signature is converted during dialect conversion, a `BlockTypeConversionRewrite` object is pushed onto the stack of rewrites. Such an object represents multiple steps:
- Splitting the old block, i.e., creating a new block and moving all operations over.
- Rewriting block arguments.
- Erasing the old block.
We have dedicated `IRRewrite` objects that represent "creating a block", "moving an op" and "erasing a block". This commit reuses those rewrite objects, so that `BlockTypeConversionRewrite::rollback` and `BlockTypeConversionRewrite::commit` have less work to do.
---
Full diff: https://github.com/llvm/llvm-project/pull/83286.diff
1 File Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+40-40)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b81495a95c80ed..cac990d498d7d3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -746,24 +746,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
- applySignatureConversion(Region *region,
+ applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(Region *region, const TypeConverter &converter,
+ convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -773,7 +776,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
//===--------------------------------------------------------------------===//
@@ -940,24 +944,10 @@ void BlockTypeConversionRewrite::commit() {
rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
}
}
-
- assert(origBlock->empty() && "expected empty block");
- origBlock->dropAllDefinedValueUses();
- delete origBlock;
- origBlock = nullptr;
}
void BlockTypeConversionRewrite::rollback() {
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- eraseBlock(block);
}
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1173,10 +1163,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Type Conversion
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
- return applySignatureConversion(block, converter, *conversion);
+ return applySignatureConversion(rewriter, block, converter, *conversion);
// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
@@ -1185,35 +1176,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion);
+ return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- Region *region, TypeConverter::SignatureConversion &conversion,
+ ConversionPatternRewriter &rewriter, Region *region,
+ TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
- return *convertBlockSignature(&region->front(), converter, &conversion);
+ return *convertBlockSignature(rewriter, &region->front(), converter,
+ &conversion);
return nullptr;
}
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
return nullptr;
- if (failed(convertNonEntryRegionTypes(region, converter)))
+ if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();
- FailureOr<Block *> newEntry =
- convertBlockSignature(&region->front(), &converter, entryConversion);
+ FailureOr<Block *> newEntry = convertBlockSignature(
+ rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter, Region *region,
+ const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1234,16 +1229,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);
- if (failed(convertBlockSignature(&block, &converter, blockConversion)))
+ if (failed(convertBlockSignature(rewriter, &block, &converter,
+ blockConversion)))
return failure();
}
return success();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- Block *block, const TypeConverter *converter,
+ ConversionPatternRewriter &rewriter, Block *block,
+ const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
- MLIRContext *ctx = eraseRewriter.getContext();
+ MLIRContext *ctx = rewriter.getContext();
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
@@ -1253,11 +1250,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Split the block at the beginning to get a new block to use for the updated
// signature.
- Block *newBlock = block->splitBlock(block->begin());
+ Block *newBlock = rewriter.splitBlock(block, block->begin());
block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1333,6 +1327,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);
+
+ // Erase the old block. (It is just unlinked for now and will be erased during
+ // cleanup.)
+ rewriter.eraseBlock(block);
+
return newBlock;
}
@@ -1531,7 +1530,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(region, conversion, converter);
+ return impl->applySignatureConversion(*this, region, conversion, converter);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1540,7 +1539,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(region, converter, entryConversion);
+ return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1549,7 +1548,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
+ return impl->convertNonEntryRegionTypes(*this, region, converter,
+ blockConversions);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2051,7 +2051,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- if (failed(impl.convertBlockSignature(block, converter))) {
+ if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/83286
More information about the llvm-branch-commits mailing list