[Mlir-commits] [mlir] [mlir][Transforms][NFC] Simplify `ArgConverter` state (PR #81462)
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 21 07:44:08 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81462
>From ae28cd9e31ac6133574d0dc78be6454a69ab3830 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:11:41 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Simplify `ArgConverter` state
* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The old block is then moved into a temporary region that is stored in `regionMapping`; it is not deleted yet, so that the conversion can be rolled back. `regionMapping` is not needed: instead of moving the old block to a temporary region, it can simply be unlinked. Block erasures are handled the same way elsewhere in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converters. That field is never accessed within `ArgConverter`, so it should be stored in `ConversionPatternRewriterImpl` instead.
---
.../Transforms/Utils/DialectConversion.cpp | 79 ++++++-------------
1 file changed, 22 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc61bc6b6260c6..88709bb2618744 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -343,23 +343,6 @@ struct ArgConverter {
const TypeConverter *converter;
};
- /// Return if the signature of the given block has already been converted.
- bool hasBeenConverted(Block *block) const {
- return conversionInfo.count(block) || convertedBlocks.count(block);
- }
-
- /// Set the type converter to use for the given region.
- void setConverter(Region *region, const TypeConverter *typeConverter) {
- assert(typeConverter && "expected valid type converter");
- regionToConverter[region] = typeConverter;
- }
-
- /// Return the type converter to use for the given region, or null if there
- /// isn't one.
- const TypeConverter *getConverter(Region *region) {
- return regionToConverter.lookup(region);
- }
-
//===--------------------------------------------------------------------===//
// Rewrite Application
//===--------------------------------------------------------------------===//
@@ -409,24 +392,10 @@ struct ArgConverter {
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
- /// Insert a new conversion into the cache.
- void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
-
/// A collection of blocks that have had their arguments converted. This is a
/// map from the new replacement block, back to the original block.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
- /// The set of original blocks that were converted.
- DenseSet<Block *> convertedBlocks;
-
- /// A mapping from valid regions, to those containing the original blocks of a
- /// conversion.
- DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
-
- /// A mapping of regions to type converters that should be used when
- /// converting the arguments of blocks within that region.
- DenseMap<Region *, const TypeConverter *> regionToConverter;
-
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
@@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) {
block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);
- // Move the operations back the original block and the delete the new block.
+ // Move the operations back the original block, move the original block back
+ // into its original location and the delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- origBlock->moveBefore(block);
+ block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
block->erase();
- convertedBlocks.erase(origBlock);
conversionInfo.erase(it);
}
@@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
mapping.lookupOrDefault(castValue, origArg.getType()));
}
}
+
+ delete origBlock;
+ blockInfo.origBlock = nullptr;
}
}
@@ -572,9 +544,11 @@ FailureOr<Block *> ArgConverter::convertSignature(
Block *block, const TypeConverter *converter,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
- // Check if the block was already converted. If the block is detached,
- // conservatively assume it is going to be deleted.
- if (hasBeenConverted(block) || !block->getParent())
+ // Check if the block was already converted.
+ // * If the block is mapped in `conversionInfo`, it is a converted block.
+ // * If the block is detached, conservatively assume that it is going to be
+ // deleted; it is likely the old block (before it was converted).
+ if (conversionInfo.count(block) || !block->getParent())
return block;
// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
@@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion(
// signature.
Block *newBlock = block->splitBlock(block->begin());
block->replaceAllUsesWith(newBlock);
+ // Unlink the block, but do not erase it yet, so that the change can be rolled
+ // back.
+ block->getParent()->getBlocks().remove(block);
// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion(
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
- // Remove the original block from the region and return the new one.
- insertConversion(newBlock, std::move(info));
- return newBlock;
-}
-
-void ArgConverter::insertConversion(Block *newBlock,
- ConvertedBlockInfo &&info) {
- // Get a region to insert the old block.
- Region *region = newBlock->getParent();
- std::unique_ptr<Region> &mappedRegion = regionMapping[region];
- if (!mappedRegion)
- mappedRegion = std::make_unique<Region>(region->getParentOp());
-
- // Move the original block to the mapped region and emplace the conversion.
- mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
- info.origBlock->getIterator());
- convertedBlocks.insert(info.origBlock);
conversionInfo.insert({newBlock, std::move(info)});
+ return newBlock;
}
//===----------------------------------------------------------------------===//
@@ -1227,6 +1188,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// active.
const TypeConverter *currentTypeConverter = nullptr;
+ /// A mapping of regions to type converters that should be used when
+ /// converting the arguments of blocks within that region.
+ DenseMap<Region *, const TypeConverter *> regionToConverter;
+
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
@@ -1504,7 +1469,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
- argConverter.setConverter(region, &converter);
+ regionToConverter[region] = &converter;
if (region->empty())
return nullptr;
@@ -1519,7 +1484,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- argConverter.setConverter(region, &converter);
+ regionToConverter[region] = &converter;
if (region->empty())
return success();
@@ -2195,7 +2160,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
- if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+ if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
More information about the Mlir-commits
mailing list