[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn block type conversion into `IRRewrite` (PR #81756)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 14 08:25:55 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit refactors the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent.

Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, block type conversions are now committed via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/blocks should go through a new `SingleEraseRewriter` (that is owned by the `ConversionPatternRewriterImpl`) if there is a chance of a double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.


---

Patch is 40.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81756.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+376-434) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 67b076b295eae8..b2baa88879b6e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewrites, unsigned numIgnoredOperations)
+                unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations) {}
+        numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
+
+  /// The current number of erased operations/blocks.
+  unsigned numErased;
 };
 
 //===----------------------------------------------------------------------===//
@@ -292,374 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
       outputType, outputType, converter, unresolvedMaterializations);
 }
 
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
-  ArgConverter(
-      PatternRewriter &rewriter,
-      SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
-      : rewriter(rewriter),
-        unresolvedMaterializations(unresolvedMaterializations) {}
-
-  /// This structure contains the information pertaining to an argument that has
-  /// been converted.
-  struct ConvertedArgInfo {
-    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                     Value castValue = nullptr)
-        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-    /// The start index of in the new argument list that contains arguments that
-    /// replace the original.
-    unsigned newArgIdx;
-
-    /// The number of arguments that replaced the original argument.
-    unsigned newArgSize;
-
-    /// The cast value that was created to cast from the new arguments to the
-    /// old. This only used if 'newArgSize' > 1.
-    Value castValue;
-  };
-
-  /// This structure contains information pertaining to a block that has had its
-  /// signature converted.
-  struct ConvertedBlockInfo {
-    ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
-        : origBlock(origBlock), converter(converter) {}
-
-    /// The original block that was requested to have its signature converted.
-    Block *origBlock;
-
-    /// The conversion information for each of the arguments. The information is
-    /// std::nullopt if the argument was dropped during conversion.
-    SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
-    /// The type converter used to convert the arguments.
-    const TypeConverter *converter;
-  };
-
-  //===--------------------------------------------------------------------===//
-  // Rewrite Application
-  //===--------------------------------------------------------------------===//
-
-  /// Erase any rewrites registered for the blocks within the given operation
-  /// which is about to be removed. This merely drops the rewrites without
-  /// undoing them.
-  void notifyOpRemoved(Operation *op);
-
-  /// Cleanup and undo any generated conversions for the arguments of block.
-  /// This method replaces the new block with the original, reverting the IR to
-  /// its original state.
-  void discardRewrites(Block *block);
-
-  /// Fully replace uses of the old arguments with the new.
-  void applyRewrites(ConversionValueMapping &mapping);
-
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(ConversionValueMapping &mapping,
-                             OpBuilder &builder,
-                             function_ref<Operation *(Value)> findLiveUser);
-
-  //===--------------------------------------------------------------------===//
-  // Conversion
-  //===--------------------------------------------------------------------===//
-
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *>
-  convertSignature(Block *block, const TypeConverter *converter,
-                   ConversionValueMapping &mapping,
-                   SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// Apply the given signature conversion on the given block. The new block
-  /// containing the updated signature is returned. If no conversions were
-  /// necessary, e.g. if the block has no arguments, `block` is returned.
-  /// `converter` is used to generate any necessary cast operations that
-  /// translate between the origin argument types and those specified in the
-  /// signature conversion.
-  Block *applySignatureConversion(
-      Block *block, const TypeConverter *converter,
-      TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping,
-      SmallVectorImpl<BlockArgument> &argReplacements);
-
-  /// A collection of blocks that have had their arguments converted. This is a
-  /// map from the new replacement block, back to the original block.
-  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
-  /// The pattern rewriter to use when materializing conversions.
-  PatternRewriter &rewriter;
-
-  /// An ordered set of unresolved materializations during conversion.
-  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
-  if (conversionInfo.empty())
-    return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region) {
-      // Drop any rewrites from within.
-      for (Operation &nestedOp : block)
-        if (nestedOp.getNumRegions())
-          notifyOpRemoved(&nestedOp);
-
-      // Check if this block was converted.
-      auto it = conversionInfo.find(&block);
-      if (it == conversionInfo.end())
-        continue;
-
-      // Drop all uses of the original arguments and delete the original block.
-      Block *origBlock = it->second.origBlock;
-      for (BlockArgument arg : origBlock->getArguments())
-        arg.dropAllUses();
-      conversionInfo.erase(it);
-    }
-  }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
-  auto it = conversionInfo.find(block);
-  if (it == conversionInfo.end())
-    return;
-  Block *origBlock = it->second.origBlock;
-
-  // Drop all uses of the new block arguments and replace uses of the new block.
-  for (int i = block->getNumArguments() - 1; i >= 0; --i)
-    block->getArgument(i).dropAllUses();
-  block->replaceAllUsesWith(origBlock);
-
-  // Move the operations back the original block, move the original block back
-  // into its original location and the delete the new block.
-  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
-  block->erase();
-
-  conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
-  for (auto &info : conversionInfo) {
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
-      BlockArgument origArg = origBlock->getArgument(i);
-
-      // Handle the case of a 1->0 value mapping.
-      if (!argInfo) {
-        if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
-          origArg.replaceAllUsesWith(newArg);
-        continue;
-      }
-
-      // Otherwise this is a 1->1+ value mapping.
-      Value castValue = argInfo->castValue;
-      assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-      // If the argument is still used, replace it with the generated cast.
-      if (!origArg.use_empty()) {
-        origArg.replaceAllUsesWith(
-            mapping.lookupOrDefault(castValue, origArg.getType()));
-      }
-    }
-
-    delete origBlock;
-    blockInfo.origBlock = nullptr;
-  }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
-    ConversionValueMapping &mapping, OpBuilder &builder,
-    function_ref<Operation *(Value)> findLiveUser) {
-  for (auto &info : conversionInfo) {
-    Block *newBlock = info.first;
-    ConvertedBlockInfo &blockInfo = info.second;
-    Block *origBlock = blockInfo.origBlock;
-
-    // Process the remapping for each of the original arguments.
-    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      // If the type of this argument changed and the argument is still live, we
-      // need to materialize a conversion.
-      BlockArgument origArg = origBlock->getArgument(i);
-      if (mapping.lookupOrNull(origArg, origArg.getType()))
-        continue;
-      Operation *liveUser = findLiveUser(origArg);
-      if (!liveUser)
-        continue;
-
-      Value replacementValue = mapping.lookupOrDefault(origArg);
-      bool isDroppedArg = replacementValue == origArg;
-      if (isDroppedArg)
-        rewriter.setInsertionPointToStart(newBlock);
-      else
-        rewriter.setInsertionPointAfterValue(replacementValue);
-      Value newArg;
-      if (blockInfo.converter) {
-        newArg = blockInfo.converter->materializeSourceConversion(
-            rewriter, origArg.getLoc(), origArg.getType(),
-            isDroppedArg ? ValueRange() : ValueRange(replacementValue));
-        assert((!newArg || newArg.getType() == origArg.getType()) &&
-               "materialization hook did not provide a value of the expected "
-               "type");
-      }
-      if (!newArg) {
-        InFlightDiagnostic diag =
-            emitError(origArg.getLoc())
-            << "failed to materialize conversion for block argument #" << i
-            << " that remained live after conversion, type was "
-            << origArg.getType();
-        if (!isDroppedArg)
-          diag << ", with target type " << replacementValue.getType();
-        diag.attachNote(liveUser->getLoc())
-            << "see existing live user here: " << *liveUser;
-        return failure();
-      }
-      mapping.map(origArg, newArg);
-    }
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
-    Block *block, const TypeConverter *converter,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  // Check if the block was already converted.
-  // * If the block is mapped in `conversionInfo`, it is a converted block.
-  // * If the block is detached, conservatively assume that it is going to be
-  //   deleted; it is likely the old block (before it was converted).
-  if (conversionInfo.count(block) || !block->getParent())
-    return block;
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping,
-                                    argReplacements);
-  return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
-    Block *block, const TypeConverter *converter,
-    TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping,
-    SmallVectorImpl<BlockArgument> &argReplacements) {
-  // If no arguments are being changed or added, there is nothing to do.
-  unsigned origArgCount = block->getNumArguments();
-  auto convertedTypes = signatureConversion.getConvertedTypes();
-  if (origArgCount == 0 && convertedTypes.empty())
-    return block;
-
-  // Split the block at the beginning to get a new block to use for the updated
-  // signature.
-  Block *newBlock = block->splitBlock(block->begin());
-  block->replaceAllUsesWith(newBlock);
-  // Unlink the block, but do not erase it yet, so that the change can be rolled
-  // back.
-  block->getParent()->getBlocks().remove(block);
-
-  // Map all new arguments to the location of the argument they originate from.
-  SmallVector<Location> newLocs(convertedTypes.size(),
-                                rewriter.getUnknownLoc());
-  for (unsigned i = 0; i < origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap || inputMap->replacementValue)
-      continue;
-    Location origLoc = block->getArgument(i).getLoc();
-    for (unsigned j = 0; j < inputMap->size; ++j)
-      newLocs[inputMap->inputNo + j] = origLoc;
-  }
-
-  SmallVector<Value, 4> newArgRange(
-      newBlock->addArguments(convertedTypes, newLocs));
-  ArrayRef<Value> newArgs(newArgRange);
-
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  ConvertedBlockInfo info(block, converter);
-  info.argInfo.resize(origArgCount);
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(newBlock);
-  for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
-    BlockArgument origArg = block->getArgument(i);
-
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
-      assert(inputMap->size == 0 &&
-             "invalid to provide a replacement value when the argument isn't "
-             "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
-      argReplacements.push_back(origArg);
-      continue;
-    }
-
-    // Otherwise, this is a 1->1+ mapping.
-    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
-
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
-    // FIXME: We simply pass through the replacement argument if there wasn't a
-    // converter, which isn't great as it allows implicit type conversions to
-    // appear. We should properly restructure this code to handle cases where a
-    // converter isn't provided and also to properly handle the case where an
-    // argument materialization is actually a temporary source materialization
-    // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-    } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
-    }
-
-    mapping.map(origArg, newArg);
-    argReplacements.push_back(origArg);
-    info.argInfo[i] =
-        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
-  }
-
-  conversionInfo.insert({newBlock, std::move(info)});
-  return newBlock;
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
@@ -694,6 +330,12 @@ class IRRewrite {
   /// Commit the rewrite.
   virtual void commit() {}
 
+  /// Erase the given op (unless it was already erased).
+  void eraseOp(Operation *op);
+
+  /// Erase the given block (unless it was already erased).
+  void eraseBlock(Block *block);
+
   Kind getKind() const { return kind; }
 
   static bool classof(const IRRewrite *rewrite) { return true; }
@@ -744,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 };
 
@@ -881,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
     // Merge back the block that was split out.
     originalBlock->getOperations().splice(originalBlock->end(),
                                           block->getOperations());
-    block->dropAllDefinedValueUses();
-    block->erase();
+    eraseBlock(block);
   }
 
 private:
@@ -890,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
   Block *originalBlock;
 };
 
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+                   Value castValue = nullptr)
+      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+  /// The start index of in the new argument list that contains arguments that
+  /// replace the original.
+  unsigned newArgIdx;
+
+  /// The number of arguments that replaced the original argument.
+  unsigned newArgSize;
+
+  /// The cast value that was created to cast from the new arguments to the
+  /// old. This only used if 'newArgSize' > 1.
+  Value castValue;
+};
+
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block)
-      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+  BlockTypeConversionRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+      const TypeConverter *converter)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  // TODO: Block type conversions are currently committed in
-  // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+  /// Materialize any necessary conversions for converted arguments that have
+  /// live users, using the provided `findLiveUser` to search for a user that
+  /// survives the conversion process.
+  LogicalResult
+  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+  void commit() override;
+
   void rollback() override;
+
+private:
+  /// The original block that was requested to have its signature converted.
+  Block *origBlock;
+
+  /// The conversion information for each of the arguments. The information is
+  /// std::nullopt if the argument was dropped during conversion.
+  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+  /// The type converter used to convert the arguments.
+  const TypeConverter...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81756


More information about the llvm-branch-commits mailing list