[Mlir-commits] [mlir] [mlir][Transforms][NFC] Simplify `ArgConverter` state (PR #81462)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 21 07:44:08 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81462

>From ae28cd9e31ac6133574d0dc78be6454a69ab3830 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:11:41 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Simplify `ArgConverter` state

* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The old block is temporarily moved into a region that is stored in `regionMapping`. The old block is not yet deleted, so that the conversion can be rolled back. `regionMapping` is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handled in the same way in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converters. This field is never accessed within `ArgConverter`. It should be stored in `ConversionPatternRewriterImpl` instead.
---
 .../Transforms/Utils/DialectConversion.cpp    | 79 ++++++-------------
 1 file changed, 22 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc61bc6b6260c6..88709bb2618744 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -343,23 +343,6 @@ struct ArgConverter {
     const TypeConverter *converter;
   };
 
-  /// Return if the signature of the given block has already been converted.
-  bool hasBeenConverted(Block *block) const {
-    return conversionInfo.count(block) || convertedBlocks.count(block);
-  }
-
-  /// Set the type converter to use for the given region.
-  void setConverter(Region *region, const TypeConverter *typeConverter) {
-    assert(typeConverter && "expected valid type converter");
-    regionToConverter[region] = typeConverter;
-  }
-
-  /// Return the type converter to use for the given region, or null if there
-  /// isn't one.
-  const TypeConverter *getConverter(Region *region) {
-    return regionToConverter.lookup(region);
-  }
-
   //===--------------------------------------------------------------------===//
   // Rewrite Application
   //===--------------------------------------------------------------------===//
@@ -409,24 +392,10 @@ struct ArgConverter {
       ConversionValueMapping &mapping,
       SmallVectorImpl<BlockArgument> &argReplacements);
 
-  /// Insert a new conversion into the cache.
-  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
-
   /// A collection of blocks that have had their arguments converted. This is a
   /// map from the new replacement block, back to the original block.
   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
 
-  /// The set of original blocks that were converted.
-  DenseSet<Block *> convertedBlocks;
-
-  /// A mapping from valid regions, to those containing the original blocks of a
-  /// conversion.
-  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
-
-  /// A mapping of regions to type converters that should be used when
-  /// converting the arguments of blocks within that region.
-  DenseMap<Region *, const TypeConverter *> regionToConverter;
-
   /// The pattern rewriter to use when materializing conversions.
   PatternRewriter &rewriter;
 
@@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) {
     block->getArgument(i).dropAllUses();
   block->replaceAllUsesWith(origBlock);
 
-  // Move the operations back the original block and the delete the new block.
+  // Move the operations back the original block, move the original block back
+  // into its original location and the delete the new block.
   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  origBlock->moveBefore(block);
+  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
   block->erase();
 
-  convertedBlocks.erase(origBlock);
   conversionInfo.erase(it);
 }
 
@@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
             mapping.lookupOrDefault(castValue, origArg.getType()));
       }
     }
+
+    delete origBlock;
+    blockInfo.origBlock = nullptr;
   }
 }
 
@@ -572,9 +544,11 @@ FailureOr<Block *> ArgConverter::convertSignature(
     Block *block, const TypeConverter *converter,
     ConversionValueMapping &mapping,
     SmallVectorImpl<BlockArgument> &argReplacements) {
-  // Check if the block was already converted. If the block is detached,
-  // conservatively assume it is going to be deleted.
-  if (hasBeenConverted(block) || !block->getParent())
+  // Check if the block was already converted.
+  // * If the block is mapped in `conversionInfo`, it is a converted block.
+  // * If the block is detached, conservatively assume that it is going to be
+  //   deleted; it is likely the old block (before it was converted).
+  if (conversionInfo.count(block) || !block->getParent())
     return block;
   // If a converter wasn't provided, and the block wasn't already converted,
   // there is nothing we can do.
@@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion(
   // signature.
   Block *newBlock = block->splitBlock(block->begin());
   block->replaceAllUsesWith(newBlock);
+  // Unlink the block, but do not erase it yet, so that the change can be rolled
+  // back.
+  block->getParent()->getBlocks().remove(block);
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion(
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  // Remove the original block from the region and return the new one.
-  insertConversion(newBlock, std::move(info));
-  return newBlock;
-}
-
-void ArgConverter::insertConversion(Block *newBlock,
-                                    ConvertedBlockInfo &&info) {
-  // Get a region to insert the old block.
-  Region *region = newBlock->getParent();
-  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
-  if (!mappedRegion)
-    mappedRegion = std::make_unique<Region>(region->getParentOp());
-
-  // Move the original block to the mapped region and emplace the conversion.
-  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
-                                   info.origBlock->getIterator());
-  convertedBlocks.insert(info.origBlock);
   conversionInfo.insert({newBlock, std::move(info)});
+  return newBlock;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1227,6 +1188,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
 
+  /// A mapping of regions to type converters that should be used when
+  /// converting the arguments of blocks within that region.
+  DenseMap<Region *, const TypeConverter *> regionToConverter;
+
   /// This allows the user to collect the match failure message.
   function_ref<void(Diagnostic &)> notifyCallback;
 
@@ -1504,7 +1469,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return nullptr;
 
@@ -1519,7 +1484,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return success();
 
@@ -2195,7 +2160,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 
     // If the region of the block has a type converter, try to convert the block
     // directly.
-    if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
       if (failed(impl.convertBlockSignature(block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));



More information about the Mlir-commits mailing list