[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify `ArgConverter` state (PR #81462)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 12 02:56:56 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The new block is temporarily inserted into a region that is stored in `regionMapping`. The old block is not yet deleted, so that the conversion can be rolled back. `regionMapping` is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handled in the same way in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converters. That field is never accessed within `ArgConverter`. It should be stored in `ConversionPatternRewriterImpl` instead.


---
Full diff: https://github.com/llvm/llvm-project/pull/81462.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+18-45) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 489ccd0139c7f2..53717f632621dd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -348,18 +348,6 @@ struct ArgConverter {
     return conversionInfo.count(block) || convertedBlocks.count(block);
   }
 
-  /// Set the type converter to use for the given region.
-  void setConverter(Region *region, const TypeConverter *typeConverter) {
-    assert(typeConverter && "expected valid type converter");
-    regionToConverter[region] = typeConverter;
-  }
-
-  /// Return the type converter to use for the given region, or null if there
-  /// isn't one.
-  const TypeConverter *getConverter(Region *region) {
-    return regionToConverter.lookup(region);
-  }
-
   //===--------------------------------------------------------------------===//
   // Rewrite Application
   //===--------------------------------------------------------------------===//
@@ -409,9 +397,6 @@ struct ArgConverter {
       ConversionValueMapping &mapping,
       SmallVectorImpl<BlockArgument> &argReplacements);
 
-  /// Insert a new conversion into the cache.
-  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
-
   /// A collection of blocks that have had their arguments converted. This is a
   /// map from the new replacement block, back to the original block.
   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
@@ -419,14 +404,6 @@ struct ArgConverter {
   /// The set of original blocks that were converted.
   DenseSet<Block *> convertedBlocks;
 
-  /// A mapping from valid regions, to those containing the original blocks of a
-  /// conversion.
-  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
-
-  /// A mapping of regions to type converters that should be used when
-  /// converting the arguments of blocks within that region.
-  DenseMap<Region *, const TypeConverter *> regionToConverter;
-
   /// The pattern rewriter to use when materializing conversions.
   PatternRewriter &rewriter;
 
@@ -474,9 +451,10 @@ void ArgConverter::discardRewrites(Block *block) {
     block->getArgument(i).dropAllUses();
   block->replaceAllUsesWith(origBlock);
 
-  // Move the operations back the original block and the delete the new block.
+  // Move the operations back the original block, move the original block back
+  // into its original location and the delete the new block.
   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  origBlock->moveBefore(block);
+  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
   block->erase();
 
   convertedBlocks.erase(origBlock);
@@ -510,6 +488,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
             mapping.lookupOrDefault(castValue, origArg.getType()));
       }
     }
+
+    delete origBlock;
+    blockInfo.origBlock = nullptr;
   }
 }
 
@@ -603,6 +584,9 @@ Block *ArgConverter::applySignatureConversion(
   // signature.
   Block *newBlock = block->splitBlock(block->begin());
   block->replaceAllUsesWith(newBlock);
+  // Unlink the block, but do not erase it yet, so that the change can be rolled
+  // back.
+  block->getParent()->getBlocks().remove(block);
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +663,9 @@ Block *ArgConverter::applySignatureConversion(
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  // Remove the original block from the region and return the new one.
-  insertConversion(newBlock, std::move(info));
-  return newBlock;
-}
-
-void ArgConverter::insertConversion(Block *newBlock,
-                                    ConvertedBlockInfo &&info) {
-  // Get a region to insert the old block.
-  Region *region = newBlock->getParent();
-  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
-  if (!mappedRegion)
-    mappedRegion = std::make_unique<Region>(region->getParentOp());
-
-  // Move the original block to the mapped region and emplace the conversion.
-  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
-                                   info.origBlock->getIterator());
-  convertedBlocks.insert(info.origBlock);
+  convertedBlocks.insert(block);
   conversionInfo.insert({newBlock, std::move(info)});
+  return newBlock;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1182,6 +1151,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
 
+  /// A mapping of regions to type converters that should be used when
+  /// converting the arguments of blocks within that region.
+  DenseMap<Region *, const TypeConverter *> regionToConverter;
+
   /// This allows the user to collect the match failure message.
   function_ref<void(Diagnostic &)> notifyCallback;
 
@@ -1459,7 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return nullptr;
 
@@ -1474,7 +1447,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return success();
 
@@ -2154,7 +2127,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 
     // If the region of the block has a type converter, try to convert the block
     // directly.
-    if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
       if (failed(impl.convertBlockSignature(block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));

``````````

</details>


https://github.com/llvm/llvm-project/pull/81462


More information about the llvm-branch-commits mailing list