[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn block type conversion into `IRRewrite` (PR #81756)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 14 08:25:26 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81756

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent.

Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, trigger "commits" of block type conversion via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/blocks should go through a new `SingleEraseRewriter` (that is owned by the `ConversionPatternRewriterImpl`) if there is a chance of a double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.


>From ae08e916df9fcee362a813b89c126716376e462f Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 14 Feb 2024 16:23:29 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn block type conversion into
 `IRRewrite`

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent.

Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, trigger "commits" of block type conversion via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/blocks should go through a new `SingleEraseRewriter` (that is owned by the `ConversionPatternRewriterImpl`) if there is a chance of a double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.
---
 .../Transforms/Utils/DialectConversion.cpp    | 810 ++++++++----------
 1 file changed, 376 insertions(+), 434 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 67b076b295eae8..b2baa88879b6e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewrites, unsigned numIgnoredOperations)
+                unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations) {}
+        numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
+
+  /// The current number of erased operations/blocks.
+  unsigned numErased;
 };
 
 //===----------------------------------------------------------------------===//
@@ -292,374 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
       outputType, outputType, converter, unresolvedMaterializations);
 }
 
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
-  ArgConverter(
-      PatternRewriter &rewriter,
-      SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
-      : rewriter(rewriter),
-        unresolvedMaterializations(unresolvedMaterializations) {}
-
-  /// This structure contains the information pertaining to an argument that has
-  /// been converted.
-  struct ConvertedArgInfo {
-    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                     Value castValue = nullptr)
-        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-    /// The start index of in the new argument list that contains arguments that
-    /// replace the original.
-    unsigned newArgIdx;
-
-    /// The number of arguments that replaced the original argument.
-    unsigned newArgSize;
-
-    /// The cast value that was created to cast from the new arguments to the
-    /// old. This only used if 'newArgSize' > 1.
-    Value castValue;
-  };
-
-  /// This structure contains information pertaining to a block that has had its
-  /// signature converted.
-  struct ConvertedBlockInfo {
-    ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
-        : origBlock(origBlock), converter(converter) {}
-
-    /// The original block that was requested to have its signature converted.
-    Block *origBlock;
-
-    /// The conversion information for each of the arguments. The information is
-    /// std::nullopt if the argument was dropped during conversion.
-    SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
-    /// The type converter used to convert the arguments.
-    const TypeConverter *converter;
-  };
-
-  //===--------------------------------------------------------------------===//
-  // Rewrite Application
-  //===--------------------------------------------------------------------===//
-
-  /// Erase any rewrites registered for the blocks within the given operation
-  /// which is about to be removed. This merely drops the rewrites without
-  /// undoing them.
-  void notifyOpRemoved(Operation *op);
-
-  /// Cleanup and undo any generated conversions for the arguments of block.
-  /// This method replaces the new block with the original, reverting the IR to
-  /// its original state.
-  void discardRewrites(Block *block);
-
-  /// Fully replace uses of the old arguments with the new.
-  void applyRewrites(ConversionValueMapping &mapping);
-
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(ConversionValueMapping &mapping,
-                             OpBuilder &builder,
-                             function_ref<Operation *(Value)> findLiveUser);
-
-  //===--------------------------------------------------------------------===//
-  // Conversion
-  //===--------------------------------------------------------------------===//
-
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *>
-  convertSignature(Block *block, const TypeConverter *converter,
-                   ConversionValueMapping &mapping,
-                   SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// Apply the given signature conversion on the given block. The new block
-  /// containing the updated signature is returned. If no conversions were
-  /// necessary, e.g. if the block has no arguments, `block` is returned.
-  /// `converter` is used to generate any necessary cast operations that
-  /// translate between the origin argument types and those specified in the
-  /// signature conversion.
-  Block *applySignatureConversion(
-      Block *block, const TypeConverter *converter,
-      TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping,
-      SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// A collection of blocks that have had their arguments converted. This is a
-  /// map from the new replacement block, back to the original block.
-  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
-  /// The pattern rewriter to use when materializing conversions.
-  PatternRewriter &rewriter;
-
-  /// An ordered set of unresolved materializations during conversion.
-  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
-  if (conversionInfo.empty())
-    return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region) {
-      // Drop any rewrites from within.
-      for (Operation &nestedOp : block)
-        if (nestedOp.getNumRegions())
-          notifyOpRemoved(&nestedOp);
-
-      // Check if this block was converted.
-      auto it = conversionInfo.find(&block);
-      if (it == conversionInfo.end())
-        continue;
-
-      // Drop all uses of the original arguments and delete the original block.
-      Block *origBlock = it->second.origBlock;
-      for (BlockArgument arg : origBlock->getArguments())
-        arg.dropAllUses();
-      conversionInfo.erase(it);
-    }
-  }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
-  auto it = conversionInfo.find(block);
-  if (it == conversionInfo.end())
-    return;
-  Block *origBlock = it->second.origBlock;
-
-  // Drop all uses of the new block arguments and replace uses of the new block.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i)
-    block->getArgument(i).dropAllUses();
-  block->replaceAllUsesWith(origBlock);
-
-  // Move the operations back the original block, move the original block back
-  // into its original location and the delete the new block.
-  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
-  block->erase();
-
-  conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
-  for (auto &info : conversionInfo) {
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
-      BlockArgument origArg = origBlock->getArgument(i);
-
-      // Handle the case of a 1->0 value mapping.
-      if (!argInfo) {
-        if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
-          origArg.replaceAllUsesWith(newArg);
-        continue;
-      }
-
-      // Otherwise this is a 1->1+ value mapping.
-      Value castValue = argInfo->castValue;
-      assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-      // If the argument is still used, replace it with the generated cast.
-      if (!origArg.use_empty()) {
-        origArg.replaceAllUsesWith(
-            mapping.lookupOrDefault(castValue, origArg.getType()));
-      }
-    }
-
-    delete origBlock;
-    blockInfo.origBlock = nullptr;
-  }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
-    ConversionValueMapping &mapping, OpBuilder &builder,
-    function_ref<Operation *(Value)> findLiveUser) {
-  for (auto &info : conversionInfo) {
-    Block *newBlock = info.first;
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      // If the type of this argument changed and the argument is still live, we
-      // need to materialize a conversion.
-      BlockArgument origArg = origBlock->getArgument(i);
-      if (mapping.lookupOrNull(origArg, origArg.getType()))
-        continue;
-      Operation *liveUser = findLiveUser(origArg);
-      if (!liveUser)
-        continue;
-
-      Value replacementValue = mapping.lookupOrDefault(origArg);
-      bool isDroppedArg = replacementValue == origArg;
-      if (isDroppedArg)
-        rewriter.setInsertionPointToStart(newBlock);
-      else
-        rewriter.setInsertionPointAfterValue(replacementValue);
-      Value newArg;
-      if (blockInfo.converter) {
-        newArg = blockInfo.converter->materializeSourceConversion(
-            rewriter, origArg.getLoc(), origArg.getType(),
-            isDroppedArg ? ValueRange() : ValueRange(replacementValue));
-        assert((!newArg || newArg.getType() == origArg.getType()) &&
-               "materialization hook did not provide a value of the expected "
-               "type");
-      }
-      if (!newArg) {
-        InFlightDiagnostic diag =
-            emitError(origArg.getLoc())
-            << "failed to materialize conversion for block argument #" << i
-            << " that remained live after conversion, type was "
-            << origArg.getType();
-        if (!isDroppedArg)
-          diag << ", with target type " << replacementValue.getType();
-        diag.attachNote(liveUser->getLoc())
-            << "see existing live user here: " << *liveUser;
-        return failure();
-      }
-      mapping.map(origArg, newArg);
-    }
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
-    Block *block, const TypeConverter *converter,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  // Check if the block was already converted.
-  // * If the block is mapped in `conversionInfo`, it is a converted block.
-  // * If the block is detached, conservatively assume that it is going to be
-  //   deleted; it is likely the old block (before it was converted).
-  if (conversionInfo.count(block) || !block->getParent())
-    return block;
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping,
-                                    argReplacements);
-  return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
-    Block *block, const TypeConverter *converter,
-    TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  // If no arguments are being changed or added, there is nothing to do.
-  unsigned origArgCount = block->getNumArguments();
-  auto convertedTypes = signatureConversion.getConvertedTypes();
-  if (origArgCount == 0 && convertedTypes.empty())
-    return block;
-
-  // Split the block at the beginning to get a new block to use for the updated
-  // signature.
-  Block *newBlock = block->splitBlock(block->begin());
-  block->replaceAllUsesWith(newBlock);
-  // Unlink the block, but do not erase it yet, so that the change can be rolled
-  // back.
-  block->getParent()->getBlocks().remove(block);
-
-  // Map all new arguments to the location of the argument they originate from.
-  SmallVector<Location> newLocs(convertedTypes.size(),
-                                rewriter.getUnknownLoc());
-  for (unsigned i = 0; i < origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap || inputMap->replacementValue)
-      continue;
-    Location origLoc = block->getArgument(i).getLoc();
-    for (unsigned j = 0; j < inputMap->size; ++j)
-      newLocs[inputMap->inputNo + j] = origLoc;
-  }
-
-  SmallVector<Value, 4> newArgRange(
-      newBlock->addArguments(convertedTypes, newLocs));
-  ArrayRef<Value> newArgs(newArgRange);
-
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  ConvertedBlockInfo info(block, converter);
-  info.argInfo.resize(origArgCount);
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(newBlock);
-  for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
-    BlockArgument origArg = block->getArgument(i);
-
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
-      assert(inputMap->size == 0 &&
-             "invalid to provide a replacement value when the argument isn't "
-             "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
-      argReplacements.push_back(origArg);
-      continue;
-    }
-
-    // Otherwise, this is a 1->1+ mapping.
-    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
-
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
-    // FIXME: We simply pass through the replacement argument if there wasn't a
-    // converter, which isn't great as it allows implicit type conversions to
-    // appear. We should properly restructure this code to handle cases where a
-    // converter isn't provided and also to properly handle the case where an
-    // argument materialization is actually a temporary source materialization
-    // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-    } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
-    }
-
-    mapping.map(origArg, newArg);
-    argReplacements.push_back(origArg);
-    info.argInfo[i] =
-        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
-  }
-
-  conversionInfo.insert({newBlock, std::move(info)});
-  return newBlock;
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
@@ -694,6 +330,12 @@ class IRRewrite {
   /// Commit the rewrite.
   virtual void commit() {}
 
+  /// Erase the given op (unless it was already erased).
+  void eraseOp(Operation *op);
+
+  /// Erase the given block (unless it was already erased).
+  void eraseBlock(Block *block);
+
   Kind getKind() const { return kind; }
 
   static bool classof(const IRRewrite *rewrite) { return true; }
@@ -744,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 };
 
@@ -881,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
     // Merge back the block that was split out.
     originalBlock->getOperations().splice(originalBlock->end(),
                                           block->getOperations());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 
 private:
@@ -890,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
   Block *originalBlock;
 };
 
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+                   Value castValue = nullptr)
+      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+  /// The start index of in the new argument list that contains arguments that
+  /// replace the original.
+  unsigned newArgIdx;
+
+  /// The number of arguments that replaced the original argument.
+  unsigned newArgSize;
+
+  /// The cast value that was created to cast from the new arguments to the
+  /// old. This only used if 'newArgSize' > 1.
+  Value castValue;
+};
+
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block)
-      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+  BlockTypeConversionRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+      const TypeConverter *converter)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  // TODO: Block type conversions are currently committed in
-  // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+  /// Materialize any necessary conversions for converted arguments that have
+  /// live users, using the provided `findLiveUser` to search for a user that
+  /// survives the conversion process.
+  LogicalResult
+  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+  void commit() override;
+
   void rollback() override;
+
+private:
+  /// The original block that was requested to have its signature converted.
+  Block *origBlock;
+
+  /// The conversion information for each of the arguments. The information is
+  /// std::nullopt if the argument was dropped during conversion.
+  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+  /// The type converter used to convert the arguments.
+  const TypeConverter *converter;
 };
 
 /// An operation rewrite.
@@ -949,8 +628,8 @@ class MoveOperationRewrite : public OperationRewrite {
   // The block in which this operation was previously contained.
   Block *block;
 
-  // The original successor of this operation before it was moved. "nullptr" if
-  // this operation was the only operation in the region.
+  // The original successor of this operation before it was moved. "nullptr"
+  // if this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
 
@@ -997,6 +676,26 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
+/// Find the single rewrite object of the specified type and block among the
+/// given rewrites. In debug mode, asserts that there is no more than one such
+/// object. Return "nullptr" if no object was found.
+template <typename RewriteTy, typename R>
+static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
+  RewriteTy *result = nullptr;
+  for (auto &rewrite : rewrites) {
+    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+    if (rewriteTy && rewriteTy->getBlock() == block) {
+#ifndef NDEBUG
+      assert(!result && "expected single matching rewrite");
+      result = rewriteTy;
+#else
+      return rewriteTy;
+#endif // NDEBUG
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -1004,7 +703,7 @@ namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : argConverter(rewriter, unresolvedMaterializations),
+      : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
   /// Cleanup and destroy any generated rewrite operations. This method is
@@ -1054,15 +753,33 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
+  /// Detach any operations nested in the given operation from their parent
+  /// blocks, and erase the given operation. This can be used when the nested
+  /// operations are scheduled for erasure themselves, so deleting the regions
+  /// of the given operation together with their content would result in
+  /// double-free. This happens, for example, when rolling back op creation in
+  /// the reverse order and if the nested ops were created before the parent op.
+  /// This function does not need to collect nested ops recursively because it
+  /// is expected to also be called for each nested op when it is about to be
+  /// deleted.
+  void detachNestedAndErase(Operation *op);
+
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
 
-  /// Convert the signature of the given block.
+  /// Attempt to convert the signature of the given block, if successful a new
+  /// block is returned containing the new arguments. Returns `block` if it did
+  /// not require conversion.
   FailureOr<Block *> convertBlockSignature(
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion *conversion = nullptr);
 
+  /// Convert the types of non-entry block arguments within the given region.
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, const TypeConverter &converter,
+      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+
   /// Apply a signature conversion on the given region, using `converter` for
   /// materializations if not null.
   Block *
@@ -1075,10 +792,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   convertRegionTypes(Region *region, const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
-  /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+  /// Apply the given signature conversion on the given block. The new block
+  /// containing the updated signature is returned. If no conversions were
+  /// necessary, e.g. if the block has no arguments, `block` is returned.
+  /// `converter` is used to generate any necessary cast operations that
+  /// translate between the origin argument types and those specified in the
+  /// signature conversion.
+  Block *applySignatureConversion(
+      Block *block, const TypeConverter *converter,
+      TypeConverter::SignatureConversion &signatureConversion);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1110,17 +832,54 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
+  //===--------------------------------------------------------------------===//
+  // IR Erasure
+  //===--------------------------------------------------------------------===//
+
+  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
+  /// operation or block is erased multiple times. This rewriter assumes that
+  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
+  struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
+  public:
+    SingleEraseRewriter(MLIRContext *context)
+        : RewriterBase(context, /*listener=*/this) {}
+
+    /// Erase the given op (unless it was already erased).
+    void eraseOp(Operation *op) override {
+      if (erased.contains(op))
+        return;
+      op->dropAllUses();
+      RewriterBase::eraseOp(op);
+    }
+
+    /// Erase the given block (unless it was already erased).
+    void eraseBlock(Block *block) override {
+      if (erased.contains(block))
+        return;
+      block->dropAllDefinedValueUses();
+      RewriterBase::eraseBlock(block);
+    }
+
+    void notifyOperationRemoved(Operation *op) override { erased.insert(op); }
+    void notifyBlockRemoved(Block *block) override { erased.insert(block); }
+
+    /// Pointers to all erased operations and blocks.
+    SetVector<void *> erased;
+  };
+
   //===--------------------------------------------------------------------===//
   // State
   //===--------------------------------------------------------------------===//
 
+  PatternRewriter &rewriter;
+
+  /// This rewriter must be used for erasing ops/blocks.
+  SingleEraseRewriter eraseRewriter;
+
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Utility used to convert block arguments.
-  ArgConverter argConverter;
-
   /// Ordered vector of all of the newly created operations during conversion.
   SmallVector<Operation *> createdOps;
 
@@ -1177,20 +936,102 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
+void IRRewrite::eraseOp(Operation *op) {
+  rewriterImpl.eraseRewriter.eraseOp(op); // No-op if `op` was already erased.
+}
+
+void IRRewrite::eraseBlock(Block *block) {
+  rewriterImpl.eraseRewriter.eraseBlock(block); // No-op if already erased.
+}
+
+void BlockTypeConversionRewrite::commit() {
+  // Redirect leftover uses of each original argument, then free the old block.
+  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+    std::optional<ConvertedArgInfo> &info = argInfo[i];
+    BlockArgument origArg = origBlock->getArgument(i);
+
+    // No conversion info recorded: handle the case of a 1->0 value mapping.
+    if (!info) {
+      if (Value newArg =
+              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+        origArg.replaceAllUsesWith(newArg);
+      continue; // Without a mapped value, the uses of origArg are left as-is.
+    }
+
+    // Otherwise this is a 1->1+ value mapping.
+    Value castValue = info->castValue;
+    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
+
+    // If the argument is still used, replace it with the (looked-up) cast.
+    if (!origArg.use_empty()) {
+      origArg.replaceAllUsesWith(
+          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+    }
+  }
+
+  delete origBlock; // NOTE(review): unmapped 1->0 args may still have uses.
+  origBlock = nullptr; // Mark the original block as released.
+}
+
 void BlockTypeConversionRewrite::rollback() {
-  // Undo the type conversion.
-  rewriterImpl.argConverter.discardRewrites(block);
-}
-
-/// Detach any operations nested in the given operation from their parent
-/// blocks, and erase the given operation. This can be used when the nested
-/// operations are scheduled for erasure themselves, so deleting the regions of
-/// the given operation together with their content would result in double-free.
-/// This happens, for example, when rolling back op creation in the reverse
-/// order and if the nested ops were created before the parent op. This function
-/// does not need to collect nested ops recursively because it is expected to
-/// also be called for each nested op when it is about to be deleted.
-static void detachNestedAndErase(Operation *op) {
+  // Drop all uses of the new block arguments and point block uses at origBlock.
+  for (int i = block->getNumArguments() - 1; i >= 0; --i)
+    block->getArgument(i).dropAllUses();
+  block->replaceAllUsesWith(origBlock);
+
+  // Move the operations back to the original block, reinsert the original block
+  // at the new block's position in its region, and then erase the new block.
+  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
+  eraseBlock(block);
+}
+
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+    function_ref<Operation *(Value)> findLiveUser) { // Null result: not live.
+  // Build a source materialization for each original argument that stays live.
+  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+    // If the type of this argument changed and the argument is still live, we
+    // need to materialize a conversion back to the original type.
+    BlockArgument origArg = origBlock->getArgument(i);
+    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+      continue; // A value of the original type is presumably mapped already.
+    Operation *liveUser = findLiveUser(origArg);
+    if (!liveUser)
+      continue; // No live user was reported; nothing to materialize.
+
+    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+    bool isDroppedArg = replacementValue == origArg; // Lookup gave origArg back.
+    if (isDroppedArg)
+      rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
+    else
+      rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+    Value newArg;
+    if (converter) { // Without a converter, newArg stays null and we fail below.
+      newArg = converter->materializeSourceConversion(
+          rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+      assert((!newArg || newArg.getType() == origArg.getType()) &&
+             "materialization hook did not provide a value of the expected "
+             "type");
+    }
+    if (!newArg) { // Report the failure and point at the live user.
+      InFlightDiagnostic diag =
+          emitError(origArg.getLoc())
+          << "failed to materialize conversion for block argument #" << i
+          << " that remained live after conversion, type was "
+          << origArg.getType();
+      if (!isDroppedArg)
+        diag << ", with target type " << replacementValue.getType();
+      diag.attachNote(liveUser->getLoc())
+          << "see existing live user here: " << *liveUser;
+      return failure();
+    }
+    rewriterImpl.mapping.map(origArg, newArg); // Record the materialized value.
+  }
+  return success();
+}
+
+void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
       while (!block.getOperations().empty())
@@ -1198,8 +1039,7 @@ static void detachNestedAndErase(Operation *op) {
       block.dropAllDefinedValueUses();
     }
   }
-  op->dropAllUses();
-  op->erase();
+  eraseRewriter.eraseOp(op);
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
@@ -1218,11 +1058,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     for (OpResult result : repl.first->getResults())
       if (Value newValue = mapping.lookupOrNull(result, result.getType()))
         result.replaceAllUsesWith(newValue);
-
-    // If this operation defines any regions, drop any pending argument
-    // rewrites.
-    if (repl.first->getNumRegions())
-      argConverter.notifyOpRemoved(repl.first);
   }
 
   // Apply all of the requested argument replacements.
@@ -1249,22 +1084,16 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
   // Drop all of the unresolved materialization operations created during
   // conversion.
-  for (auto &mat : unresolvedMaterializations) {
-    mat.getOp()->dropAllUses();
-    mat.getOp()->erase();
-  }
+  for (auto &mat : unresolvedMaterializations)
+    eraseRewriter.eraseOp(mat.getOp());
 
   // In a second pass, erase all of the replaced operations in reverse. This
   // allows processing nested operations before their parent region is
   // destroyed. Because we process in reverse order, producers may be deleted
   // before their users (a pattern deleting a producer and then the consumer)
   // so we first drop all uses explicitly.
-  for (auto &repl : llvm::reverse(replacements)) {
-    repl.first->dropAllUses();
-    repl.first->erase();
-  }
-
-  argConverter.applyRewrites(mapping);
+  for (auto &repl : llvm::reverse(replacements))
+    eraseRewriter.eraseOp(repl.first);
 
   // Commit all rewrites.
   for (auto &rewrite : rewrites)
@@ -1277,7 +1106,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewrites.size(), ignoredOps.size());
+                       rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1325,6 +1155,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (!operationsWithChangedResults.empty() &&
          operationsWithChangedResults.back() >= state.numReplacements)
     operationsWithChangedResults.pop_back();
+
+  while (eraseRewriter.erased.size() != state.numErased)
+    eraseRewriter.erased.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1413,18 +1246,28 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion *conversion) {
-  FailureOr<Block *> result =
-      conversion ? argConverter.applySignatureConversion(
-                       block, converter, *conversion, mapping, argReplacements)
-                 : argConverter.convertSignature(block, converter, mapping,
-                                                 argReplacements);
-  if (failed(result))
+  // Check if the block was already converted; if so, return it unchanged.
+  // * Case 1: `block` is the converted block.
+  if (auto *typeConversionRewrite =
+          findSingleRewrite<BlockTypeConversionRewrite>(rewrites, block))
+    return block; // NOTE(review): typeConversionRewrite is tested, never used.
+  // * Case 2: `block` is detached. Conservatively assume that it is going to be
+  //   deleted; it is likely the old block (before it was converted).
+  if (!block->getParent())
+    return block;
+
+  if (conversion) // The caller supplied an explicit signature conversion.
+    return applySignatureConversion(block, converter, *conversion);
+
+  // If a converter wasn't provided, and the block wasn't already converted,
+  // there is nothing we can do.
+  if (!converter)
     return failure();
-  if (Block *newBlock = *result) {
-    if (newBlock != block)
-      appendRewrite<BlockTypeConversionRewrite>(newBlock);
-  }
-  return result;
+
+  // Try to compute a signature conversion for the block with the converter.
+  if (auto conversion = converter->convertBlockSignature(block))
+    return applySignatureConversion(block, converter, *conversion); // Shadows.
+  return failure();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1478,6 +1321,102 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
   return success();
 }
 
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
+    Block *block, const TypeConverter *converter,
+    TypeConverter::SignatureConversion &signatureConversion) {
+  // Nothing to do when the block has no arguments and none are being added.
+  unsigned origArgCount = block->getNumArguments();
+  auto convertedTypes = signatureConversion.getConvertedTypes();
+  if (origArgCount == 0 && convertedTypes.empty())
+    return block;
+
+  // Split the block at the beginning to get a new block to use for the updated
+  // signature.
+  Block *newBlock = block->splitBlock(block->begin());
+  block->replaceAllUsesWith(newBlock);
+  // Unlink the block, but do not erase it yet, so that the change can be rolled
+  // back.
+  block->getParent()->getBlocks().remove(block);
+
+  // Map all new arguments to the location of the argument they originate from.
+  SmallVector<Location> newLocs(convertedTypes.size(),
+                                rewriter.getUnknownLoc());
+  for (unsigned i = 0; i < origArgCount; ++i) {
+    auto inputMap = signatureConversion.getInputMapping(i);
+    if (!inputMap || inputMap->replacementValue)
+      continue; // No new arguments originate from this one; keep unknown loc.
+    Location origLoc = block->getArgument(i).getLoc();
+    for (unsigned j = 0; j < inputMap->size; ++j)
+      newLocs[inputMap->inputNo + j] = origLoc;
+  }
+
+  SmallVector<Value, 4> newArgRange(
+      newBlock->addArguments(convertedTypes, newLocs));
+  ArrayRef<Value> newArgs(newArgRange);
+
+  // Remap each of the original arguments as determined by the signature
+  // conversion. Entries of `argInfo` stay empty for unmapped/replaced args.
+  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+  argInfo.resize(origArgCount);
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(newBlock);
+  for (unsigned i = 0; i != origArgCount; ++i) {
+    auto inputMap = signatureConversion.getInputMapping(i);
+    if (!inputMap)
+      continue; // No input mapping for this argument: nothing is recorded.
+    BlockArgument origArg = block->getArgument(i);
+
+    // If inputMap->replacementValue is not nullptr, then the argument is
+    // dropped and a replacement value is provided to be the remappedValue.
+    if (inputMap->replacementValue) {
+      assert(inputMap->size == 0 &&
+             "invalid to provide a replacement value when the argument isn't "
+             "dropped");
+      mapping.map(origArg, inputMap->replacementValue);
+      argReplacements.push_back(origArg);
+      continue;
+    }
+
+    // Otherwise, this is a 1->1+ mapping.
+    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+    Value newArg;
+
+    // If this is a 1->1 mapping and the types of new and replacement arguments
+    // match (i.e. it's an identity map), then the argument is mapped to its
+    // original type.
+    // FIXME: We simply pass through the replacement argument if there wasn't a
+    // converter, which isn't great as it allows implicit type conversions to
+    // appear. We should properly restructure this code to handle cases where a
+    // converter isn't provided and also to properly handle the case where an
+    // argument materialization is actually a temporary source materialization
+    // (e.g. in the case of 1->N).
+    if (replArgs.size() == 1 &&
+        (!converter || replArgs[0].getType() == origArg.getType())) {
+      newArg = replArgs.front();
+    } else {
+      Type origOutputType = origArg.getType();
+
+      // Legalize the argument output type; keep the original type on failure.
+      Type outputType = origOutputType;
+      if (Type legalOutputType = converter->convertType(outputType)) // NOTE(review): converter may be null here when replArgs.size() != 1.
+        outputType = legalOutputType;
+
+      newArg = buildUnresolvedArgumentMaterialization(
+          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+          converter, unresolvedMaterializations);
+    }
+
+    mapping.map(origArg, newArg);
+    argReplacements.push_back(origArg);
+    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
+  }
+
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
+                                            converter); // Enables commit/rollback.
+  return newBlock;
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2603,8 +2542,11 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  return rewriterImpl.argConverter.materializeLiveConversions(
-      rewriterImpl.mapping, rewriter, findLiveUser);
+  for (auto &r : rewriterImpl.rewrites)
+    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
+      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+        return failure();
+  return success();
 }
 
 /// Replace the results of a materialization operation with the given values.



More information about the llvm-branch-commits mailing list