[Mlir-commits] [mlir] aaf5c81 - [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (#83286)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 01:06:04 PST 2024


Author: Matthias Springer
Date: 2024-03-04T18:06:00+09:00
New Revision: aaf5c818b3b84a49f0c14879c141c7770f0fcf4b

URL: https://github.com/llvm/llvm-project/commit/aaf5c818b3b84a49f0c14879c141c7770f0fcf4b
DIFF: https://github.com/llvm/llvm-project/commit/aaf5c818b3b84a49f0c14879c141c7770f0fcf4b.diff

LOG: [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` (#83286)

When a block signature is converted during dialect conversion, a
`BlockTypeConversionRewrite` object is stored in the stack of rewrites.
Such an object represents multiple steps:
- Splitting the old block, i.e., creating a new block and moving all
operations over.
- Rewriting block arguments.
- Erasing the old block.

We have dedicated `IRRewrite` objects that represent "creating a block",
"moving an op" and "erasing a block". This commit reuses those rewrite
objects, so that there is less work to do in
`BlockTypeConversionRewrite::rollback` and
`BlockTypeConversionRewrite::commit`/`cleanup`.

Note: This change is in preparation for adding listener support to the
dialect conversion. The less work a `commit` function does, the fewer
notifications will have to be sent.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7846f1ab56811a..71aee4447f7921 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -441,8 +441,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
 
   void commit() override;
 
-  void cleanup() override;
-
   void rollback() override;
 
 private:
@@ -791,24 +789,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// block is returned containing the new arguments. Returns `block` if it did
   /// not require conversion.
   FailureOr<Block *> convertBlockSignature(
-      Block *block, const TypeConverter *converter,
+      ConversionPatternRewriter &rewriter, Block *block,
+      const TypeConverter *converter,
       TypeConverter::SignatureConversion *conversion = nullptr);
 
   /// Convert the types of non-entry block arguments within the given region.
   LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
+      ConversionPatternRewriter &rewriter, Region *region,
+      const TypeConverter &converter,
       ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
 
   /// Apply a signature conversion on the given region, using `converter` for
   /// materializations if not null.
   Block *
-  applySignatureConversion(Region *region,
+  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
                            TypeConverter::SignatureConversion &conversion,
                            const TypeConverter *converter);
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
-  convertRegionTypes(Region *region, const TypeConverter &converter,
+  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
+                     const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Apply the given signature conversion on the given block. The new block
@@ -818,7 +819,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// translate between the origin argument types and those specified in the
   /// signature conversion.
   Block *applySignatureConversion(
-      Block *block, const TypeConverter *converter,
+      ConversionPatternRewriter &rewriter, Block *block,
+      const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
   //===--------------------------------------------------------------------===//
@@ -991,24 +993,8 @@ void BlockTypeConversionRewrite::commit() {
   }
 }
 
-void BlockTypeConversionRewrite::cleanup() {
-  assert(origBlock->empty() && "expected empty block");
-  origBlock->dropAllDefinedValueUses();
-  delete origBlock;
-  origBlock = nullptr;
-}
-
 void BlockTypeConversionRewrite::rollback() {
-  // Drop all uses of the new block arguments and replace uses of the new block.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i)
-    block->getArgument(i).dropAllUses();
   block->replaceAllUsesWith(origBlock);
-
-  // Move the operations back the original block, move the original block back
-  // into its original location and the delete the new block.
-  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
-  eraseBlock(block);
 }
 
 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
@@ -1224,10 +1210,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 // Type Conversion
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
-    Block *block, const TypeConverter *converter,
+    ConversionPatternRewriter &rewriter, Block *block,
+    const TypeConverter *converter,
     TypeConverter::SignatureConversion *conversion) {
   if (conversion)
-    return applySignatureConversion(block, converter, *conversion);
+    return applySignatureConversion(rewriter, block, converter, *conversion);
 
   // If a converter wasn't provided, and the block wasn't already converted,
   // there is nothing we can do.
@@ -1236,35 +1223,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
 
   // Try to convert the signature for the block with the provided converter.
   if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion);
+    return applySignatureConversion(rewriter, block, converter, *conversion);
   return failure();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion,
+    ConversionPatternRewriter &rewriter, Region *region,
+    TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
   if (!region->empty())
-    return *convertBlockSignature(&region->front(), converter, &conversion);
+    return *convertBlockSignature(rewriter, &region->front(), converter,
+                                  &conversion);
   return nullptr;
 }
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
-    Region *region, const TypeConverter &converter,
+    ConversionPatternRewriter &rewriter, Region *region,
+    const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
   regionToConverter[region] = &converter;
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(region, converter)))
+  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
     return failure();
 
-  FailureOr<Block *> newEntry =
-      convertBlockSignature(&region->front(), &converter, entryConversion);
+  FailureOr<Block *> newEntry = convertBlockSignature(
+      rewriter, &region->front(), &converter, entryConversion);
   return newEntry;
 }
 
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    Region *region, const TypeConverter &converter,
+    ConversionPatternRewriter &rewriter, Region *region,
+    const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
   regionToConverter[region] = &converter;
   if (region->empty())
@@ -1285,16 +1276,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
             : const_cast<TypeConverter::SignatureConversion *>(
                   &blockConversions[blockIdx++]);
 
-    if (failed(convertBlockSignature(&block, &converter, blockConversion)))
+    if (failed(convertBlockSignature(rewriter, &block, &converter,
+                                     blockConversion)))
       return failure();
   }
   return success();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    Block *block, const TypeConverter *converter,
+    ConversionPatternRewriter &rewriter, Block *block,
+    const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
-  MLIRContext *ctx = eraseRewriter.getContext();
+  MLIRContext *ctx = rewriter.getContext();
 
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
@@ -1304,11 +1297,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   // Split the block at the beginning to get a new block to use for the updated
   // signature.
-  Block *newBlock = block->splitBlock(block->begin());
+  Block *newBlock = rewriter.splitBlock(block, block->begin());
   block->replaceAllUsesWith(newBlock);
-  // Unlink the block, but do not erase it yet, so that the change can be rolled
-  // back.
-  block->getParent()->getBlocks().remove(block);
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
@@ -1384,6 +1374,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
                                             converter);
+
+  // Erase the old block. (It is just unlinked for now and will be erased during
+  // cleanup.)
+  rewriter.eraseBlock(block);
+
   return newBlock;
 }
 
@@ -1592,7 +1587,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(region, conversion, converter);
+  return impl->applySignatureConversion(*this, region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1601,7 +1596,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertRegionTypes(region, converter, entryConversion);
+  return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
@@ -1610,7 +1605,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
+  return impl->convertNonEntryRegionTypes(*this, region, converter,
+                                          blockConversions);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2104,7 +2100,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // If the region of the block has a type converter, try to convert the block
     // directly.
     if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      if (failed(impl.convertBlockSignature(block, converter))) {
+      if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();


        


More information about the Mlir-commits mailing list