[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify `ArgConverter` state (PR #81462)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 12 02:56:26 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81462
* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The old block is temporarily moved into a region that is stored in `regionMapping`; it is not deleted yet, so that the conversion can be rolled back. `regionMapping` is not needed: instead of moving the old block to a temporary region, it can simply be unlinked. Block erasures are handled in the same way in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converters. That field is never accessed within `ArgConverter`, so it should be stored in `ConversionPatternRewriterImpl` instead.
>From 5af1476c5d439c14122ffb50d24923e522a61b32 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 10:53:59 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Simplify `ArgConverter` state
* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The old block is temporarily moved into a region that is stored in `regionMapping`; it is not deleted yet, so that the conversion can be rolled back. `regionMapping` is not needed: instead of moving the old block to a temporary region, it can simply be unlinked. Block erasures are handled in the same way in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converters. That field is never accessed within `ArgConverter`, so it should be stored in `ConversionPatternRewriterImpl` instead.
---
.../Transforms/Utils/DialectConversion.cpp | 63 ++++++-------------
1 file changed, 18 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 489ccd0139c7f..53717f632621d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -348,18 +348,6 @@ struct ArgConverter {
return conversionInfo.count(block) || convertedBlocks.count(block);
}
- /// Set the type converter to use for the given region.
- void setConverter(Region *region, const TypeConverter *typeConverter) {
- assert(typeConverter && "expected valid type converter");
- regionToConverter[region] = typeConverter;
- }
-
- /// Return the type converter to use for the given region, or null if there
- /// isn't one.
- const TypeConverter *getConverter(Region *region) {
- return regionToConverter.lookup(region);
- }
-
//===--------------------------------------------------------------------===//
// Rewrite Application
//===--------------------------------------------------------------------===//
@@ -409,9 +397,6 @@ struct ArgConverter {
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
- /// Insert a new conversion into the cache.
- void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
-
/// A collection of blocks that have had their arguments converted. This is a
/// map from the new replacement block, back to the original block.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
@@ -419,14 +404,6 @@ struct ArgConverter {
/// The set of original blocks that were converted.
DenseSet<Block *> convertedBlocks;
- /// A mapping from valid regions, to those containing the original blocks of a
- /// conversion.
- DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
-
- /// A mapping of regions to type converters that should be used when
- /// converting the arguments of blocks within that region.
- DenseMap<Region *, const TypeConverter *> regionToConverter;
-
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
@@ -474,9 +451,10 @@ void ArgConverter::discardRewrites(Block *block) {
block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);
- // Move the operations back the original block and the delete the new block.
+ // Move the operations back to the original block, move the original block
+ // back into its original location, and then delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- origBlock->moveBefore(block);
+ block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
block->erase();
convertedBlocks.erase(origBlock);
@@ -510,6 +488,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
mapping.lookupOrDefault(castValue, origArg.getType()));
}
}
+
+ delete origBlock;
+ blockInfo.origBlock = nullptr;
}
}
@@ -603,6 +584,9 @@ Block *ArgConverter::applySignatureConversion(
// signature.
Block *newBlock = block->splitBlock(block->begin());
block->replaceAllUsesWith(newBlock);
+ // Unlink the block, but do not erase it yet, so that the change can be rolled
+ // back.
+ block->getParent()->getBlocks().remove(block);
// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +663,9 @@ Block *ArgConverter::applySignatureConversion(
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
- // Remove the original block from the region and return the new one.
- insertConversion(newBlock, std::move(info));
- return newBlock;
-}
-
-void ArgConverter::insertConversion(Block *newBlock,
- ConvertedBlockInfo &&info) {
- // Get a region to insert the old block.
- Region *region = newBlock->getParent();
- std::unique_ptr<Region> &mappedRegion = regionMapping[region];
- if (!mappedRegion)
- mappedRegion = std::make_unique<Region>(region->getParentOp());
-
- // Move the original block to the mapped region and emplace the conversion.
- mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
- info.origBlock->getIterator());
- convertedBlocks.insert(info.origBlock);
+ convertedBlocks.insert(block);
conversionInfo.insert({newBlock, std::move(info)});
+ return newBlock;
}
//===----------------------------------------------------------------------===//
@@ -1182,6 +1151,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// active.
const TypeConverter *currentTypeConverter = nullptr;
+ /// A mapping of regions to type converters that should be used when
+ /// converting the arguments of blocks within that region.
+ DenseMap<Region *, const TypeConverter *> regionToConverter;
+
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
@@ -1459,7 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
- argConverter.setConverter(region, &converter);
+ regionToConverter[region] = &converter;
if (region->empty())
return nullptr;
@@ -1474,7 +1447,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- argConverter.setConverter(region, &converter);
+ regionToConverter[region] = &converter;
if (region->empty())
return success();
@@ -2154,7 +2127,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
- if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+ if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
More information about the llvm-branch-commits
mailing list