[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn block type conversion into `IRRewrite` (PR #81756)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 16 07:17:19 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81756

From 5dc79f6af9ac00a61767062980b13eb4ae8d2571 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:13:37 +0000
Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Improve signature
 conversion API

This commit improves the block signature conversion API of the dialect conversion.

There is the following comment in `ArgConverter::applySignatureConversion`:
```
// If no arguments are being changed or added, there is nothing to do.
```

However, the implementation actually used to replace a block with a new block even if the block argument types do not change (i.e., there is "nothing to do"). This is fixed in this commit. The documentation of the public `ConversionPatternRewriter` API is updated accordingly.
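
The difference between the old and the new early-exit condition can be sketched outside of MLIR. The following is a simplified model, not the actual implementation: `TypeList`, `isNoOpOld` and `isNoOpNew` are hypothetical names, and strings stand in for MLIR types.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the argument types of a block.
using TypeList = std::vector<std::string>;

// Old check: only "no arguments before, no arguments after" was treated as
// "nothing to do". An unchanged, non-empty signature still replaced the block.
bool isNoOpOld(const TypeList &current, const TypeList &converted) {
  return current.empty() && converted.empty();
}

// New check: any conversion that yields the same signature is a no-op. This
// mirrors the element-wise comparison done by `llvm::equal` in the patch.
bool isNoOpNew(const TypeList &current, const TypeList &converted) {
  return std::equal(current.begin(), current.end(), converted.begin(),
                    converted.end());
}
```

In this model, a block with a single `i32` argument whose converted signature is again a single `i32` is a no-op only under the new check.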

This commit also removes a check that used to *sometimes* skip a block signature conversion if the block was already converted. This is not consistent with the public `ConversionPatternRewriter` API; blocks should always be converted, regardless of whether they were already converted or not.

Block signature conversion also used to be silently skipped when the specified block was detached. Instead of silently skipping, an assertion is triggered. Attempting to convert a detached block (which is likely an erased block) is invalid API usage.
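
The change in skip behavior can be modeled as follows. This is a simplified sketch and not MLIR code: `Region`, `Block`, `skipsConversionOld` and `isValidConversionTarget` are hypothetical stand-ins that only capture the two conditions discussed above.

```cpp
#include <cassert>
#include <set>

// Hypothetical, minimal stand-ins for the MLIR classes of the same name.
struct Region {};
struct Block {
  Region *parent = nullptr; // nullptr models a detached block
};

// Old behavior: the conversion was silently skipped if the block was already
// converted or if it was detached.
bool skipsConversionOld(const Block &block,
                        const std::set<const Block *> &alreadyConverted) {
  return alreadyConverted.count(&block) != 0 || block.parent == nullptr;
}

// New behavior: nothing is skipped. Being attached to a region is a
// precondition that the patch enforces with an assertion.
bool isValidConversionTarget(const Block &block) {
  return block.parent != nullptr;
}
```

Under this model, a detached block used to be skipped without any diagnostic, whereas it now fails the precondition.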
---
 mlir/include/mlir/Transforms/DialectConversion.h | 12 +++++++++---
 mlir/lib/Transforms/Utils/DialectConversion.cpp  | 10 +++-------
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 0d7722aa07ee38..2575be4cdea1ac 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -663,6 +663,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Apply a signature conversion to the entry block of the given region. This
   /// replaces the entry block with a new block containing the updated
   /// signature. The new entry block to the region is returned for convenience.
+  /// If no block argument types are changing, the entry original block will be
+  /// left in place and returned.
   ///
   /// If provided, `converter` will be used for any materializations.
   Block *
@@ -671,8 +673,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
                            const TypeConverter *converter = nullptr);
 
   /// Convert the types of block arguments within the given region. This
-  /// replaces each block with a new block containing the updated signature. The
-  /// entry block may have a special conversion if `entryConversion` is
+  /// replaces each block with a new block containing the updated signature. If
+  /// an updated signature would match the current signature, the respective
+  /// block is left in place as is.
+  ///
+  /// The entry block may have a special conversion if `entryConversion` is
   /// provided. On success, the new entry block to the region is returned for
   /// convenience. Otherwise, failure is returned.
   FailureOr<Block *> convertRegionTypes(
@@ -681,7 +686,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// Convert the types of block arguments within the given region except for
   /// the entry region. This replaces each non-entry block with a new block
-  /// containing the updated signature.
+  /// containing the updated signature. If an updated signature would match the
+  /// current signature, the respective block is left in place as is.
   ///
   /// If special conversion behavior is needed for the non-entry blocks (for
   /// example, we need to convert only a subset of a BB arguments), such
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 35028001a03dd9..c16bb144efecf5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -544,12 +544,8 @@ FailureOr<Block *> ArgConverter::convertSignature(
     Block *block, const TypeConverter *converter,
     ConversionValueMapping &mapping,
     SmallVectorImpl<BlockArgument> &argReplacements) {
-  // Check if the block was already converted.
-  // * If the block is mapped in `conversionInfo`, it is a converted block.
-  // * If the block is detached, conservatively assume that it is going to be
-  //   deleted; it is likely the old block (before it was converted).
-  if (conversionInfo.count(block) || !block->getParent())
-    return block;
+  assert(block->getParent() && "cannot convert signature of detached block");
+
   // If a converter wasn't provided, and the block wasn't already converted,
   // there is nothing we can do.
   if (!converter)
@@ -570,7 +566,7 @@ Block *ArgConverter::applySignatureConversion(
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
-  if (origArgCount == 0 && convertedTypes.empty())
+  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
     return block;
 
   // Split the block at the beginning to get a new block to use for the updated

From 61e82f6ab048ddca789a4e20e6b56781915157bf Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:16:16 +0000
Subject: [PATCH 2/2] [mlir][Transforms][NFC] Turn block type conversion into
 `IRRewrite`

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent.
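
The commit/rollback protocol can be illustrated with a small model. This is a sketch under assumptions, not the code from this patch: `Rewrite`, `LoggingRewrite`, `commitAll` and `rollbackAll` are hypothetical names, and committing in creation order is an assumption of the sketch (rolling back in reverse order is what the patch's comments describe).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of an IR rewrite: every rewrite can be finalized
// (commit) or undone (rollback).
class Rewrite {
public:
  virtual ~Rewrite() = default;
  virtual void commit() {}
  virtual void rollback() = 0;
};

// A rewrite that only records what happened to it, for illustration.
class LoggingRewrite : public Rewrite {
public:
  LoggingRewrite(std::string name, std::vector<std::string> &log)
      : name(std::move(name)), log(log) {}
  void commit() override { log.push_back("commit " + name); }
  void rollback() override { log.push_back("rollback " + name); }

private:
  std::string name;
  std::vector<std::string> &log;
};

using RewriteList = std::vector<std::unique_ptr<Rewrite>>;

// On success, finalize all rewrites (assumed here: in creation order).
void commitAll(RewriteList &rewrites) {
  for (auto &rewrite : rewrites)
    rewrite->commit();
  rewrites.clear();
}

// On failure, undo all rewrites in reverse order of creation.
void rollbackAll(RewriteList &rewrites) {
  for (auto it = rewrites.rbegin(); it != rewrites.rend(); ++it)
    (*it)->rollback();
  rewrites.clear();
}
```

With both operations available on every rewrite kind, a block type conversion needs no special-case handling in either path.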

Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, trigger "commits" of block type conversion via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/blocks should go through a new `SingleEraseRewriter` (that is owned by the `ConversionPatternRewriterImpl`) if there is a chance of double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.
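
The "erase at most once" idea behind `SingleEraseRewriter` can be sketched independently of MLIR. The class below is a hypothetical model (`EraseOnce` is not part of the patch): it tracks erased entities by address and ignores repeated erase requests. It merges the "already erased?" check and the removal notification into one step, whereas the patch records erasures through listener callbacks.

```cpp
#include <cassert>
#include <unordered_set>

// Hypothetical model of a rewriter that never erases the same entity twice.
// Entities are identified by their address only.
class EraseOnce {
public:
  // Returns true if this call erased the entity, false if the entity had
  // already been erased and the request was ignored.
  bool erase(const void *entity) {
    if (!erased.insert(entity).second)
      return false; // already erased: ignore to avoid a double-free
    // A real implementation would free the op/block here.
    return true;
  }

  // Number of distinct entities erased so far.
  unsigned getNumErased() const { return static_cast<unsigned>(erased.size()); }

private:
  std::unordered_set<const void *> erased;
};
```

Because entities are tracked by address, this scheme is only sound if no new entity is allocated at a previously erased address while the set is in use. That matches the assumption documented for `SingleEraseRewriter`: no new IR is created between calls to `eraseOp`/`eraseBlock`.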

---
 .../Transforms/Utils/DialectConversion.cpp    | 796 ++++++++----------
 1 file changed, 366 insertions(+), 430 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c16bb144efecf5..30133a14dbae56 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewrites, unsigned numIgnoredOperations)
+                unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations) {}
+        numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
+
+  /// The current number of erased operations/blocks.
+  unsigned numErased;
 };
 
 //===----------------------------------------------------------------------===//
@@ -292,370 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
       outputType, outputType, converter, unresolvedMaterializations);
 }
 
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
-  ArgConverter(
-      PatternRewriter &rewriter,
-      SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
-      : rewriter(rewriter),
-        unresolvedMaterializations(unresolvedMaterializations) {}
-
-  /// This structure contains the information pertaining to an argument that has
-  /// been converted.
-  struct ConvertedArgInfo {
-    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                     Value castValue = nullptr)
-        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-    /// The start index of in the new argument list that contains arguments that
-    /// replace the original.
-    unsigned newArgIdx;
-
-    /// The number of arguments that replaced the original argument.
-    unsigned newArgSize;
-
-    /// The cast value that was created to cast from the new arguments to the
-    /// old. This only used if 'newArgSize' > 1.
-    Value castValue;
-  };
-
-  /// This structure contains information pertaining to a block that has had its
-  /// signature converted.
-  struct ConvertedBlockInfo {
-    ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
-        : origBlock(origBlock), converter(converter) {}
-
-    /// The original block that was requested to have its signature converted.
-    Block *origBlock;
-
-    /// The conversion information for each of the arguments. The information is
-    /// std::nullopt if the argument was dropped during conversion.
-    SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
-    /// The type converter used to convert the arguments.
-    const TypeConverter *converter;
-  };
-
-  //===--------------------------------------------------------------------===//
-  // Rewrite Application
-  //===--------------------------------------------------------------------===//
-
-  /// Erase any rewrites registered for the blocks within the given operation
-  /// which is about to be removed. This merely drops the rewrites without
-  /// undoing them.
-  void notifyOpRemoved(Operation *op);
-
-  /// Cleanup and undo any generated conversions for the arguments of block.
-  /// This method replaces the new block with the original, reverting the IR to
-  /// its original state.
-  void discardRewrites(Block *block);
-
-  /// Fully replace uses of the old arguments with the new.
-  void applyRewrites(ConversionValueMapping &mapping);
-
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(ConversionValueMapping &mapping,
-                             OpBuilder &builder,
-                             function_ref<Operation *(Value)> findLiveUser);
-
-  //===--------------------------------------------------------------------===//
-  // Conversion
-  //===--------------------------------------------------------------------===//
-
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *>
-  convertSignature(Block *block, const TypeConverter *converter,
-                   ConversionValueMapping &mapping,
-                   SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// Apply the given signature conversion on the given block. The new block
-  /// containing the updated signature is returned. If no conversions were
-  /// necessary, e.g. if the block has no arguments, `block` is returned.
-  /// `converter` is used to generate any necessary cast operations that
-  /// translate between the origin argument types and those specified in the
-  /// signature conversion.
-  Block *applySignatureConversion(
-      Block *block, const TypeConverter *converter,
-      TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping,
-      SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// A collection of blocks that have had their arguments converted. This is a
-  /// map from the new replacement block, back to the original block.
-  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
-  /// The pattern rewriter to use when materializing conversions.
-  PatternRewriter &rewriter;
-
-  /// An ordered set of unresolved materializations during conversion.
-  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
-  if (conversionInfo.empty())
-    return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region) {
-      // Drop any rewrites from within.
-      for (Operation &nestedOp : block)
-        if (nestedOp.getNumRegions())
-          notifyOpRemoved(&nestedOp);
-
-      // Check if this block was converted.
-      auto it = conversionInfo.find(&block);
-      if (it == conversionInfo.end())
-        continue;
-
-      // Drop all uses of the original arguments and delete the original block.
-      Block *origBlock = it->second.origBlock;
-      for (BlockArgument arg : origBlock->getArguments())
-        arg.dropAllUses();
-      conversionInfo.erase(it);
-    }
-  }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
-  auto it = conversionInfo.find(block);
-  if (it == conversionInfo.end())
-    return;
-  Block *origBlock = it->second.origBlock;
-
-  // Drop all uses of the new block arguments and replace uses of the new block.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i)
-    block->getArgument(i).dropAllUses();
-  block->replaceAllUsesWith(origBlock);
-
-  // Move the operations back the original block, move the original block back
-  // into its original location and the delete the new block.
-  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
-  block->erase();
-
-  conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
-  for (auto &info : conversionInfo) {
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
-      BlockArgument origArg = origBlock->getArgument(i);
-
-      // Handle the case of a 1->0 value mapping.
-      if (!argInfo) {
-        if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
-          origArg.replaceAllUsesWith(newArg);
-        continue;
-      }
-
-      // Otherwise this is a 1->1+ value mapping.
-      Value castValue = argInfo->castValue;
-      assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-      // If the argument is still used, replace it with the generated cast.
-      if (!origArg.use_empty()) {
-        origArg.replaceAllUsesWith(
-            mapping.lookupOrDefault(castValue, origArg.getType()));
-      }
-    }
-
-    delete origBlock;
-    blockInfo.origBlock = nullptr;
-  }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
-    ConversionValueMapping &mapping, OpBuilder &builder,
-    function_ref<Operation *(Value)> findLiveUser) {
-  for (auto &info : conversionInfo) {
-    Block *newBlock = info.first;
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      // If the type of this argument changed and the argument is still live, we
-      // need to materialize a conversion.
-      BlockArgument origArg = origBlock->getArgument(i);
-      if (mapping.lookupOrNull(origArg, origArg.getType()))
-        continue;
-      Operation *liveUser = findLiveUser(origArg);
-      if (!liveUser)
-        continue;
-
-      Value replacementValue = mapping.lookupOrDefault(origArg);
-      bool isDroppedArg = replacementValue == origArg;
-      if (isDroppedArg)
-        rewriter.setInsertionPointToStart(newBlock);
-      else
-        rewriter.setInsertionPointAfterValue(replacementValue);
-      Value newArg;
-      if (blockInfo.converter) {
-        newArg = blockInfo.converter->materializeSourceConversion(
-            rewriter, origArg.getLoc(), origArg.getType(),
-            isDroppedArg ? ValueRange() : ValueRange(replacementValue));
-        assert((!newArg || newArg.getType() == origArg.getType()) &&
-               "materialization hook did not provide a value of the expected "
-               "type");
-      }
-      if (!newArg) {
-        InFlightDiagnostic diag =
-            emitError(origArg.getLoc())
-            << "failed to materialize conversion for block argument #" << i
-            << " that remained live after conversion, type was "
-            << origArg.getType();
-        if (!isDroppedArg)
-          diag << ", with target type " << replacementValue.getType();
-        diag.attachNote(liveUser->getLoc())
-            << "see existing live user here: " << *liveUser;
-        return failure();
-      }
-      mapping.map(origArg, newArg);
-    }
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
-    Block *block, const TypeConverter *converter,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  assert(block->getParent() && "cannot convert signature of detached block");
-
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping,
-                                    argReplacements);
-  return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
-    Block *block, const TypeConverter *converter,
-    TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  // If no arguments are being changed or added, there is nothing to do.
-  unsigned origArgCount = block->getNumArguments();
-  auto convertedTypes = signatureConversion.getConvertedTypes();
-  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
-    return block;
-
-  // Split the block at the beginning to get a new block to use for the updated
-  // signature.
-  Block *newBlock = block->splitBlock(block->begin());
-  block->replaceAllUsesWith(newBlock);
-  // Unlink the block, but do not erase it yet, so that the change can be rolled
-  // back.
-  block->getParent()->getBlocks().remove(block);
-
-  // Map all new arguments to the location of the argument they originate from.
-  SmallVector<Location> newLocs(convertedTypes.size(),
-                                rewriter.getUnknownLoc());
-  for (unsigned i = 0; i < origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap || inputMap->replacementValue)
-      continue;
-    Location origLoc = block->getArgument(i).getLoc();
-    for (unsigned j = 0; j < inputMap->size; ++j)
-      newLocs[inputMap->inputNo + j] = origLoc;
-  }
-
-  SmallVector<Value, 4> newArgRange(
-      newBlock->addArguments(convertedTypes, newLocs));
-  ArrayRef<Value> newArgs(newArgRange);
-
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  ConvertedBlockInfo info(block, converter);
-  info.argInfo.resize(origArgCount);
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(newBlock);
-  for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
-    BlockArgument origArg = block->getArgument(i);
-
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
-      assert(inputMap->size == 0 &&
-             "invalid to provide a replacement value when the argument isn't "
-             "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
-      argReplacements.push_back(origArg);
-      continue;
-    }
-
-    // Otherwise, this is a 1->1+ mapping.
-    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
-
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
-    // FIXME: We simply pass through the replacement argument if there wasn't a
-    // converter, which isn't great as it allows implicit type conversions to
-    // appear. We should properly restructure this code to handle cases where a
-    // converter isn't provided and also to properly handle the case where an
-    // argument materialization is actually a temporary source materialization
-    // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-    } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
-    }
-
-    mapping.map(origArg, newArg);
-    argReplacements.push_back(origArg);
-    info.argInfo[i] =
-        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
-  }
-
-  conversionInfo.insert({newBlock, std::move(info)});
-  return newBlock;
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
@@ -698,6 +338,12 @@ class IRRewrite {
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
+  /// Erase the given op (unless it was already erased).
+  void eraseOp(Operation *op);
+
+  /// Erase the given block (unless it was already erased).
+  void eraseBlock(Block *block);
+
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -740,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 };
 
@@ -877,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
     // Merge back the block that was split out.
     originalBlock->getOperations().splice(originalBlock->end(),
                                           block->getOperations());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 
 private:
@@ -886,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
   Block *originalBlock;
 };
 
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+                   Value castValue = nullptr)
+      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+  /// The start index of in the new argument list that contains arguments that
+  /// replace the original.
+  unsigned newArgIdx;
+
+  /// The number of arguments that replaced the original argument.
+  unsigned newArgSize;
+
+  /// The cast value that was created to cast from the new arguments to the
+  /// old. This only used if 'newArgSize' > 1.
+  Value castValue;
+};
+
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block)
-      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+  BlockTypeConversionRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+      const TypeConverter *converter)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  // TODO: Block type conversions are currently committed in
-  // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+  /// Materialize any necessary conversions for converted arguments that have
+  /// live users, using the provided `findLiveUser` to search for a user that
+  /// survives the conversion process.
+  LogicalResult
+  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+  void commit() override;
+
   void rollback() override;
+
+private:
+  /// The original block that was requested to have its signature converted.
+  Block *origBlock;
+
+  /// The conversion information for each of the arguments. The information is
+  /// std::nullopt if the argument was dropped during conversion.
+  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+  /// The type converter used to convert the arguments.
+  const TypeConverter *converter;
 };
 
 /// An operation rewrite.
@@ -945,8 +628,8 @@ class MoveOperationRewrite : public OperationRewrite {
   // The block in which this operation was previously contained.
   Block *block;
 
-  // The original successor of this operation before it was moved. "nullptr" if
-  // this operation was the only operation in the region.
+  // The original successor of this operation before it was moved. "nullptr"
+  // if this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
 
@@ -993,6 +676,26 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
+/// Find the single rewrite object of the specified type and block among the
+/// given rewrites. In debug mode, asserts that there is mo more than one such
+/// object. Return "nullptr" if no object was found.
+template <typename RewriteTy, typename R>
+static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
+  RewriteTy *result = nullptr;
+  for (auto &rewrite : rewrites) {
+    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+    if (rewriteTy && rewriteTy->getBlock() == block) {
+#ifndef NDEBUG
+      assert(!result && "expected single matching rewrite");
+      result = rewriteTy;
+#else
+      return rewriteTy;
+#endif // NDEBUG
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -1000,7 +703,7 @@ namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : argConverter(rewriter, unresolvedMaterializations),
+      : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
   /// Cleanup and destroy any generated rewrite operations. This method is
@@ -1050,15 +753,33 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
+  /// Detach any operations nested in the given operation from their parent
+  /// blocks, and erase the given operation. This can be used when the nested
+  /// operations are scheduled for erasure themselves, so deleting the regions
+  /// of the given operation together with their content would result in
+  /// double-free. This happens, for example, when rolling back op creation in
+  /// the reverse order and if the nested ops were created before the parent op.
+  /// This function does not need to collect nested ops recursively because it
+  /// is expected to also be called for each nested op when it is about to be
+  /// deleted.
+  void detachNestedAndErase(Operation *op);
+
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
 
-  /// Convert the signature of the given block.
+  /// Attempt to convert the signature of the given block, if successful a new
+  /// block is returned containing the new arguments. Returns `block` if it did
+  /// not require conversion.
   FailureOr<Block *> convertBlockSignature(
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion *conversion = nullptr);
 
+  /// Convert the types of non-entry block arguments within the given region.
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, const TypeConverter &converter,
+      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+
   /// Apply a signature conversion on the given region, using `converter` for
   /// materializations if not null.
   Block *
@@ -1071,10 +792,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   convertRegionTypes(Region *region, const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
-  /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+  /// Apply the given signature conversion on the given block. The new block
+  /// containing the updated signature is returned. If no conversions were
+  /// necessary, e.g. if the block has no arguments, `block` is returned.
+  /// `converter` is used to generate any necessary cast operations that
+  /// translate between the original argument types and those specified in
+  /// the signature conversion.
+  Block *applySignatureConversion(
+      Block *block, const TypeConverter *converter,
+      TypeConverter::SignatureConversion &signatureConversion);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1106,17 +832,54 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
+  //===--------------------------------------------------------------------===//
+  // IR Erasure
+  //===--------------------------------------------------------------------===//
+
+  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
+  /// operation or block is erased multiple times. This rewriter assumes that
+  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
+  struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
+  public:
+    SingleEraseRewriter(MLIRContext *context)
+        : RewriterBase(context, /*listener=*/this) {}
+
+    /// Erase the given op (unless it was already erased).
+    void eraseOp(Operation *op) override {
+      if (erased.contains(op))
+        return;
+      op->dropAllUses();
+      RewriterBase::eraseOp(op);
+    }
+
+    /// Erase the given block (unless it was already erased).
+    void eraseBlock(Block *block) override {
+      if (erased.contains(block))
+        return;
+      block->dropAllDefinedValueUses();
+      RewriterBase::eraseBlock(block);
+    }
+
+    void notifyOperationRemoved(Operation *op) override { erased.insert(op); }
+    void notifyBlockRemoved(Block *block) override { erased.insert(block); }
+
+    /// Pointers to all erased operations and blocks.
+    SetVector<void *> erased;
+  };
+
   //===--------------------------------------------------------------------===//
   // State
   //===--------------------------------------------------------------------===//
 
+  PatternRewriter &rewriter;
+
+  /// This rewriter must be used for erasing ops/blocks.
+  SingleEraseRewriter eraseRewriter;
+
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Utility used to convert block arguments.
-  ArgConverter argConverter;
-
   /// Ordered vector of all of the newly created operations during conversion.
   SmallVector<Operation *> createdOps;
 
@@ -1173,20 +936,102 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
+void IRRewrite::eraseOp(Operation *op) {
+  rewriterImpl.eraseRewriter.eraseOp(op);
+}
+
+void IRRewrite::eraseBlock(Block *block) {
+  rewriterImpl.eraseRewriter.eraseBlock(block);
+}
+
+void BlockTypeConversionRewrite::commit() {
+  // Process the remapping for each of the original arguments.
+  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+    std::optional<ConvertedArgInfo> &info = argInfo[i];
+    BlockArgument origArg = origBlock->getArgument(i);
+
+    // Handle the case of a 1->0 value mapping.
+    if (!info) {
+      if (Value newArg =
+              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+        origArg.replaceAllUsesWith(newArg);
+      continue;
+    }
+
+    // Otherwise this is a 1->1+ value mapping.
+    Value castValue = info->castValue;
+    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
+
+    // If the argument is still used, replace it with the generated cast.
+    if (!origArg.use_empty()) {
+      origArg.replaceAllUsesWith(
+          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+    }
+  }
+
+  delete origBlock;
+  origBlock = nullptr;
+}
+
 void BlockTypeConversionRewrite::rollback() {
-  // Undo the type conversion.
-  rewriterImpl.argConverter.discardRewrites(block);
-}
-
-/// Detach any operations nested in the given operation from their parent
-/// blocks, and erase the given operation. This can be used when the nested
-/// operations are scheduled for erasure themselves, so deleting the regions of
-/// the given operation together with their content would result in double-free.
-/// This happens, for example, when rolling back op creation in the reverse
-/// order and if the nested ops were created before the parent op. This function
-/// does not need to collect nested ops recursively because it is expected to
-/// also be called for each nested op when it is about to be deleted.
-static void detachNestedAndErase(Operation *op) {
+  // Drop all uses of the new block arguments and replace uses of the new block.
+  for (int i = block->getNumArguments() - 1; i >= 0; --i)
+    block->getArgument(i).dropAllUses();
+  block->replaceAllUsesWith(origBlock);
+
+  // Move the operations back to the original block, move the original block
+  // back into its original location, and then delete the new block.
+  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
+  eraseBlock(block);
+}
+
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+    function_ref<Operation *(Value)> findLiveUser) {
+  // Process the remapping for each of the original arguments.
+  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+    // If the type of this argument changed and the argument is still live, we
+    // need to materialize a conversion.
+    BlockArgument origArg = origBlock->getArgument(i);
+    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+      continue;
+    Operation *liveUser = findLiveUser(origArg);
+    if (!liveUser)
+      continue;
+
+    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+    bool isDroppedArg = replacementValue == origArg;
+    if (isDroppedArg)
+      rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
+    else
+      rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+    Value newArg;
+    if (converter) {
+      newArg = converter->materializeSourceConversion(
+          rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+      assert((!newArg || newArg.getType() == origArg.getType()) &&
+             "materialization hook did not provide a value of the expected "
+             "type");
+    }
+    if (!newArg) {
+      InFlightDiagnostic diag =
+          emitError(origArg.getLoc())
+          << "failed to materialize conversion for block argument #" << i
+          << " that remained live after conversion, type was "
+          << origArg.getType();
+      if (!isDroppedArg)
+        diag << ", with target type " << replacementValue.getType();
+      diag.attachNote(liveUser->getLoc())
+          << "see existing live user here: " << *liveUser;
+      return failure();
+    }
+    rewriterImpl.mapping.map(origArg, newArg);
+  }
+  return success();
+}
+
+void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
       while (!block.getOperations().empty())
@@ -1194,8 +1039,7 @@ static void detachNestedAndErase(Operation *op) {
       block.dropAllDefinedValueUses();
     }
   }
-  op->dropAllUses();
-  op->erase();
+  eraseRewriter.eraseOp(op);
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
@@ -1214,11 +1058,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     for (OpResult result : repl.first->getResults())
       if (Value newValue = mapping.lookupOrNull(result, result.getType()))
         result.replaceAllUsesWith(newValue);
-
-    // If this operation defines any regions, drop any pending argument
-    // rewrites.
-    if (repl.first->getNumRegions())
-      argConverter.notifyOpRemoved(repl.first);
   }
 
   // Apply all of the requested argument replacements.
@@ -1245,22 +1084,16 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
   // Drop all of the unresolved materialization operations created during
   // conversion.
-  for (auto &mat : unresolvedMaterializations) {
-    mat.getOp()->dropAllUses();
-    mat.getOp()->erase();
-  }
+  for (auto &mat : unresolvedMaterializations)
+    eraseRewriter.eraseOp(mat.getOp());
 
   // In a second pass, erase all of the replaced operations in reverse. This
   // allows processing nested operations before their parent region is
   // destroyed. Because we process in reverse order, producers may be deleted
   // before their users (a pattern deleting a producer and then the consumer)
   // so we first drop all uses explicitly.
-  for (auto &repl : llvm::reverse(replacements)) {
-    repl.first->dropAllUses();
-    repl.first->erase();
-  }
-
-  argConverter.applyRewrites(mapping);
+  for (auto &repl : llvm::reverse(replacements))
+    eraseRewriter.eraseOp(repl.first);
 
   // Commit all rewrites.
   for (auto &rewrite : rewrites)
@@ -1273,7 +1106,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewrites.size(), ignoredOps.size());
+                       rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1321,6 +1155,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (!operationsWithChangedResults.empty() &&
          operationsWithChangedResults.back() >= state.numReplacements)
     operationsWithChangedResults.pop_back();
+
+  while (eraseRewriter.erased.size() != state.numErased)
+    eraseRewriter.erased.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1409,18 +1246,18 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion *conversion) {
-  FailureOr<Block *> result =
-      conversion ? argConverter.applySignatureConversion(
-                       block, converter, *conversion, mapping, argReplacements)
-                 : argConverter.convertSignature(block, converter, mapping,
-                                                 argReplacements);
-  if (failed(result))
+  if (conversion)
+    return applySignatureConversion(block, converter, *conversion);
+
+  // If a converter wasn't provided, and the block wasn't already converted,
+  // there is nothing we can do.
+  if (!converter)
     return failure();
-  if (Block *newBlock = *result) {
-    if (newBlock != block)
-      appendRewrite<BlockTypeConversionRewrite>(newBlock);
-  }
-  return result;
+
+  // Try to convert the signature for the block with the provided converter.
+  if (auto conversion = converter->convertBlockSignature(block))
+    return applySignatureConversion(block, converter, *conversion);
+  return failure();
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1474,6 +1311,102 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
   return success();
 }
 
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
+    Block *block, const TypeConverter *converter,
+    TypeConverter::SignatureConversion &signatureConversion) {
+  // If no arguments are being changed or added, there is nothing to do.
+  unsigned origArgCount = block->getNumArguments();
+  auto convertedTypes = signatureConversion.getConvertedTypes();
+  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
+    return block;
+
+  // Split the block at the beginning to get a new block to use for the updated
+  // signature.
+  Block *newBlock = block->splitBlock(block->begin());
+  block->replaceAllUsesWith(newBlock);
+  // Unlink the block, but do not erase it yet, so that the change can be rolled
+  // back.
+  block->getParent()->getBlocks().remove(block);
+
+  // Map all new arguments to the location of the argument they originate from.
+  SmallVector<Location> newLocs(convertedTypes.size(),
+                                rewriter.getUnknownLoc());
+  for (unsigned i = 0; i < origArgCount; ++i) {
+    auto inputMap = signatureConversion.getInputMapping(i);
+    if (!inputMap || inputMap->replacementValue)
+      continue;
+    Location origLoc = block->getArgument(i).getLoc();
+    for (unsigned j = 0; j < inputMap->size; ++j)
+      newLocs[inputMap->inputNo + j] = origLoc;
+  }
+
+  SmallVector<Value, 4> newArgRange(
+      newBlock->addArguments(convertedTypes, newLocs));
+  ArrayRef<Value> newArgs(newArgRange);
+
+  // Remap each of the original arguments as determined by the signature
+  // conversion.
+  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+  argInfo.resize(origArgCount);
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(newBlock);
+  for (unsigned i = 0; i != origArgCount; ++i) {
+    auto inputMap = signatureConversion.getInputMapping(i);
+    if (!inputMap)
+      continue;
+    BlockArgument origArg = block->getArgument(i);
+
+    // If inputMap->replacementValue is not nullptr, then the argument is
+    // dropped and a replacement value is provided to be the remappedValue.
+    if (inputMap->replacementValue) {
+      assert(inputMap->size == 0 &&
+             "invalid to provide a replacement value when the argument isn't "
+             "dropped");
+      mapping.map(origArg, inputMap->replacementValue);
+      argReplacements.push_back(origArg);
+      continue;
+    }
+
+    // Otherwise, this is a 1->1+ mapping.
+    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+    Value newArg;
+
+    // If this is a 1->1 mapping and the types of new and replacement arguments
+    // match (i.e. it's an identity map), then the argument is mapped to its
+    // original type.
+    // FIXME: We simply pass through the replacement argument if there wasn't a
+    // converter, which isn't great as it allows implicit type conversions to
+    // appear. We should properly restructure this code to handle cases where a
+    // converter isn't provided and also to properly handle the case where an
+    // argument materialization is actually a temporary source materialization
+    // (e.g. in the case of 1->N).
+    if (replArgs.size() == 1 &&
+        (!converter || replArgs[0].getType() == origArg.getType())) {
+      newArg = replArgs.front();
+    } else {
+      Type origOutputType = origArg.getType();
+
+      // Legalize the argument output type.
+      Type outputType = origOutputType;
+      if (Type legalOutputType = converter->convertType(outputType))
+        outputType = legalOutputType;
+
+      newArg = buildUnresolvedArgumentMaterialization(
+          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+          converter, unresolvedMaterializations);
+    }
+
+    mapping.map(origArg, newArg);
+    argReplacements.push_back(origArg);
+    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
+  }
+
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
+                                            converter);
+  return newBlock;
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2601,8 +2534,11 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  return rewriterImpl.argConverter.materializeLiveConversions(
-      rewriterImpl.mapping, rewriter, findLiveUser);
+  for (auto &r : rewriterImpl.rewrites)
+    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
+      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+        return failure();
+  return success();
 }
 
 /// Replace the results of a materialization operation with the given values.
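
Reader's note on the `SingleEraseRewriter` bookkeeping in the patch above: the following is a minimal standalone sketch of the "erase at most once, with a record that can be truncated on rollback" idea. It does not use MLIR; `Node`, `SingleEraser`, `erase`, `numErased` and `truncate` are hypothetical names chosen for illustration, and an erase counter stands in for actual deletion. It only approximates what `erased.contains(...)`, `getCurrentState()` and the `pop_back` loop in `resetState()` appear to do.

```cpp
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an IR object (op or block); not an MLIR type.
struct Node {
  explicit Node(int id) : id(id) {}
  int id;
  int eraseCount = 0; // Counts simulated erasures instead of freeing memory.
};

// Records erased pointers in insertion order, so that a repeated erase
// request is a no-op and the record can be cut back to a saved size.
class SingleEraser {
public:
  // Returns true if `node` was erased by this call, false if an earlier
  // call had already erased it.
  bool erase(Node *node) {
    if (!seen.insert(node).second)
      return false;
    order.push_back(node);
    ++node->eraseCount;
    return true;
  }

  // Number of recorded erasures; a caller can save this as rollback state.
  std::size_t numErased() const { return order.size(); }

  // Forget every erasure recorded after the first `count` entries. This
  // only shrinks the record; it does not undo the erasure itself.
  void truncate(std::size_t count) {
    while (order.size() > count) {
      seen.erase(order.back());
      order.pop_back();
    }
  }

private:
  std::vector<Node *> order;
  std::unordered_set<Node *> seen;
};
```

In this sketch a caller would save `numErased()` before attempting a rewrite and pass it to `truncate` if the rewrite is abandoned, in the same spirit as the `numErased` field added to `RewriterState`.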


