[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (PR #83286)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 28 08:35:27 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

When a block signature is converted during dialect conversion, a `BlockTypeConversionRewrite` object is stored in the stack of rewrites. Such an object represents multiple steps:
- Splitting the old block, i.e., creating a new block and moving all operations over.
- Rewriting block arguments.
- Erasing the old block.

We have dedicated `IRRewrite` objects that represent "creating a block", "moving an op" and "erasing a block". This commit reuses those rewrite objects, so that less work has to be done in `BlockTypeConversionRewrite::rollback` and `BlockTypeConversionRewrite::commit`.

---
Full diff: https://github.com/llvm/llvm-project/pull/83286.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+40-40) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b81495a95c80ed..cac990d498d7d3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -746,24 +746,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// block is returned containing the new arguments. Returns `block` if it did
   /// not require conversion.
   FailureOr<Block *> convertBlockSignature(
-      Block *block, const TypeConverter *converter,
+      ConversionPatternRewriter &rewriter, Block *block,
+      const TypeConverter *converter,
       TypeConverter::SignatureConversion *conversion = nullptr);
 
   /// Convert the types of non-entry block arguments within the given region.
   LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
+      ConversionPatternRewriter &rewriter, Region *region,
+      const TypeConverter &converter,
       ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
 
   /// Apply a signature conversion on the given region, using `converter` for
   /// materializations if not null.
   Block *
-  applySignatureConversion(Region *region,
+  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
                            TypeConverter::SignatureConversion &conversion,
                            const TypeConverter *converter);
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
-  convertRegionTypes(Region *region, const TypeConverter &converter,
+  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
+                     const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Apply the given signature conversion on the given block. The new block
@@ -773,7 +776,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// translate between the origin argument types and those specified in the
   /// signature conversion.
   Block *applySignatureConversion(
-      Block *block, const TypeConverter *converter,
+      ConversionPatternRewriter &rewriter, Block *block,
+      const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
   //===--------------------------------------------------------------------===//
@@ -940,24 +944,10 @@ void BlockTypeConversionRewrite::commit() {
           rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
     }
   }
-
-  assert(origBlock->empty() && "expected empty block");
-  origBlock->dropAllDefinedValueUses();
-  delete origBlock;
-  origBlock = nullptr;
 }
 
 void BlockTypeConversionRewrite::rollback() {
-  // Drop all uses of the new block arguments and replace uses of the new block.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i)
-    block->getArgument(i).dropAllUses();
   block->replaceAllUsesWith(origBlock);
-
-  // Move the operations back the original block, move the original block back
-  // into its original location and the delete the new block.
-  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
-  eraseBlock(block);
 }
 
 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1173,10 +1163,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 // Type Conversion
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
-    Block *block, const TypeConverter *converter,
+    ConversionPatternRewriter &rewriter, Block *block,
+    const TypeConverter *converter,
     TypeConverter::SignatureConversion *conversion) {
   if (conversion)
-    return applySignatureConversion(block, converter, *conversion);
+    return applySignatureConversion(rewriter, block, converter, *conversion);
 
   // If a converter wasn't provided, and the block wasn't already converted,
   // there is nothing we can do.
@@ -1185,35 +1176,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
 
   // Try to convert the signature for the block with the provided converter.
   if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion);
+    return applySignatureConversion(rewriter, block, converter, *conversion);
   return failure();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion,
+    ConversionPatternRewriter &rewriter, Region *region,
+    TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
   if (!region->empty())
-    return *convertBlockSignature(&region->front(), converter, &conversion);
+    return *convertBlockSignature(rewriter, &region->front(), converter,
+                                  &conversion);
   return nullptr;
 }
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
-    Region *region, const TypeConverter &converter,
+    ConversionPatternRewriter &rewriter, Region *region,
+    const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
   regionToConverter[region] = &converter;
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(region, converter)))
+  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
     return failure();
 
-  FailureOr<Block *> newEntry =
-      convertBlockSignature(&region->front(), &converter, entryConversion);
+  FailureOr<Block *> newEntry = convertBlockSignature(
+      rewriter, &region->front(), &converter, entryConversion);
   return newEntry;
 }
 
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    Region *region, const TypeConverter &converter,
+    ConversionPatternRewriter &rewriter, Region *region,
+    const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
   regionToConverter[region] = &converter;
   if (region->empty())
@@ -1234,16 +1229,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
             : const_cast<TypeConverter::SignatureConversion *>(
                   &blockConversions[blockIdx++]);
 
-    if (failed(convertBlockSignature(&block, &converter, blockConversion)))
+    if (failed(convertBlockSignature(rewriter, &block, &converter,
+                                     blockConversion)))
       return failure();
   }
   return success();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    Block *block, const TypeConverter *converter,
+    ConversionPatternRewriter &rewriter, Block *block,
+    const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
-  MLIRContext *ctx = eraseRewriter.getContext();
+  MLIRContext *ctx = rewriter.getContext();
 
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
@@ -1253,11 +1250,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   // Split the block at the beginning to get a new block to use for the updated
   // signature.
-  Block *newBlock = block->splitBlock(block->begin());
+  Block *newBlock = rewriter.splitBlock(block, block->begin());
   block->replaceAllUsesWith(newBlock);
-  // Unlink the block, but do not erase it yet, so that the change can be rolled
-  // back.
-  block->getParent()->getBlocks().remove(block);
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1333,6 +1327,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
                                             converter);
+
+  // Erase the old block. (It is just unlinked for now and will be erased during
+  // cleanup.)
+  rewriter.eraseBlock(block);
+
   return newBlock;
 }
 
@@ -1531,7 +1530,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(region, conversion, converter);
+  return impl->applySignatureConversion(*this, region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1540,7 +1539,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertRegionTypes(region, converter, entryConversion);
+  return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1549,7 +1548,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
+  return impl->convertNonEntryRegionTypes(*this, region, converter,
+                                          blockConversions);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2051,7 +2051,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // If the region of the block has a type converter, try to convert the block
     // directly.
     if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      if (failed(impl.convertBlockSignature(block, converter))) {
+      if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/83286


More information about the llvm-branch-commits mailing list