[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn block type conversion into `IRRewrite` (PR #81756)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Feb 14 08:25:56 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent.
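To illustrate the commit/rollback model described above, here is a minimal, self-contained sketch. It does not use MLIR at all; `ToyRewrite`, `ToyAppend`, and `ToyRewriteList` are hypothetical names invented for this example and are not part of `DialectConversion.cpp`. It only shows the general idea, assuming rewrites are undone in reverse order on failure and finalized in creation order on success.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an IR rewrite: every rewrite can be rolled back,
// and may optionally do extra work when it is committed.
struct ToyRewrite {
  virtual ~ToyRewrite() = default;
  virtual void commit() {}
  virtual void rollback() = 0;
};

// Toy rewrite that is "partially reflected" in the state right away: the
// constructor appends a name to a log, and rollback removes it again.
struct ToyAppend : ToyRewrite {
  ToyAppend(std::vector<std::string> &target, std::string name)
      : log(target) {
    log.push_back(std::move(name));
  }
  void rollback() override { log.pop_back(); }
  std::vector<std::string> &log;
};

// Hypothetical owner of the rewrite list.
struct ToyRewriteList {
  std::vector<std::unique_ptr<ToyRewrite>> rewrites;

  // Failure path: undo rewrites in reverse order until `keep` remain.
  void rollbackTo(std::size_t keep) {
    while (rewrites.size() > keep) {
      rewrites.back()->rollback();
      rewrites.pop_back();
    }
  }

  // Success path: finalize all rewrites in creation order.
  void commitAll() {
    for (auto &rewrite : rewrites)
      rewrite->commit();
    rewrites.clear();
  }
};
```

With a uniform interface like this, a rewrite kind that only implements `rollback` is a special case; giving it a `commit` as well lets the owner treat every entry in the list the same way.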
Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, trigger "commits" of block type conversion via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/blocks should go through a new `SingleEraseRewriter` (owned by the `ConversionPatternRewriterImpl`) if there is a chance of a double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.
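
The "ignore the erase if it was already freed" behavior from the last bullet can be sketched in isolation. This is not the `SingleEraseRewriter` from the patch; `ToyNode` and `ToySingleEraser` are hypothetical names, and the sketch assumes that remembering the addresses of freed objects is sufficient to detect a repeated erase request.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>

// Hypothetical stand-in for an op or block. `alive` counts live nodes so
// that a double-free would be observable in a test.
struct ToyNode {
  explicit ToyNode(int &counter) : alive(counter) { ++alive; }
  ~ToyNode() { --alive; }
  int &alive;
};

// Frees each node at most once; repeated erase requests are ignored.
class ToySingleEraser {
public:
  // Returns true if this call freed the node, false if it was ignored.
  bool erase(ToyNode *node) {
    // The stored address is only used as a key and is never dereferenced.
    if (!erased.insert(node).second)
      return false; // Already freed earlier.
    delete node;
    return true;
  }

  // Number of distinct nodes freed so far.
  std::size_t numErased() const { return erased.size(); }

private:
  // Addresses of freed nodes. Caveat of this sketch: if the allocator hands
  // out the same address again, a later erase would be wrongly ignored.
  std::unordered_set<const void *> erased;
};
```

Routing all erasures through one object keeps the bookkeeping in a single place, so individual rewrites do not need to be notified when something they reference has been erased.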
---
Patch is 40.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81756.diff
1 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+376-434)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 67b076b295eae8..b2baa88879b6e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations)
+ unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations) {}
+ numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of erased operations/blocks.
+ unsigned numErased;
};
//===----------------------------------------------------------------------===//
@@ -292,374 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
outputType, outputType, converter, unresolvedMaterializations);
}
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
- ArgConverter(
- PatternRewriter &rewriter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
- : rewriter(rewriter),
- unresolvedMaterializations(unresolvedMaterializations) {}
-
- /// This structure contains the information pertaining to an argument that has
- /// been converted.
- struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
- };
-
- /// This structure contains information pertaining to a block that has had its
- /// signature converted.
- struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
- : origBlock(origBlock), converter(converter) {}
-
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
-
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
- /// The type converter used to convert the arguments.
- const TypeConverter *converter;
- };
-
- //===--------------------------------------------------------------------===//
- // Rewrite Application
- //===--------------------------------------------------------------------===//
-
- /// Erase any rewrites registered for the blocks within the given operation
- /// which is about to be removed. This merely drops the rewrites without
- /// undoing them.
- void notifyOpRemoved(Operation *op);
-
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method replaces the new block with the original, reverting the IR to
- /// its original state.
- void discardRewrites(Block *block);
-
- /// Fully replace uses of the old arguments with the new.
- void applyRewrites(ConversionValueMapping &mapping);
-
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(ConversionValueMapping &mapping,
- OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser);
-
- //===--------------------------------------------------------------------===//
- // Conversion
- //===--------------------------------------------------------------------===//
-
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *>
- convertSignature(Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned. If no conversions were
- /// necessary, e.g. if the block has no arguments, `block` is returned.
- /// `converter` is used to generate any necessary cast operations that
- /// translate between the origin argument types and those specified in the
- /// signature conversion.
- Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// A collection of blocks that have had their arguments converted. This is a
- /// map from the new replacement block, back to the original block.
- llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
- /// The pattern rewriter to use when materializing conversions.
- PatternRewriter &rewriter;
-
- /// An ordered set of unresolved materializations during conversion.
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
- if (conversionInfo.empty())
- return;
-
- for (Region &region : op->getRegions()) {
- for (Block &block : region) {
- // Drop any rewrites from within.
- for (Operation &nestedOp : block)
- if (nestedOp.getNumRegions())
- notifyOpRemoved(&nestedOp);
-
- // Check if this block was converted.
- auto it = conversionInfo.find(&block);
- if (it == conversionInfo.end())
- continue;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument arg : origBlock->getArguments())
- arg.dropAllUses();
- conversionInfo.erase(it);
- }
- }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
- auto it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
- Block *origBlock = it->second.origBlock;
-
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
- block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- block->erase();
-
- conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
- for (auto &info : conversionInfo) {
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument origArg = origBlock->getArgument(i);
-
- // Handle the case of a 1->0 value mapping.
- if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(castValue, origArg.getType()));
- }
- }
-
- delete origBlock;
- blockInfo.origBlock = nullptr;
- }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
- ConversionValueMapping &mapping, OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser) {
- for (auto &info : conversionInfo) {
- Block *newBlock = info.first;
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- BlockArgument origArg = origBlock->getArgument(i);
- if (mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriter.setInsertionPointToStart(newBlock);
- else
- rewriter.setInsertionPointAfterValue(replacementValue);
- Value newArg;
- if (blockInfo.converter) {
- newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #" << i
- << " that remained live after conversion, type was "
- << origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- mapping.map(origArg, newArg);
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // Check if the block was already converted.
- // * If the block is mapped in `conversionInfo`, it is a converted block.
- // * If the block is detached, conservatively assume that it is going to be
- // deleted; it is likely the old block (before it was converted).
- if (conversionInfo.count(block) || !block->getParent())
- return block;
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping,
- argReplacements);
- return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // If no arguments are being changed or added, there is nothing to do.
- unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedTypes();
- if (origArgCount == 0 && convertedTypes.empty())
- return block;
-
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = block->splitBlock(block->begin());
- block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
-
- // Map all new arguments to the location of the argument they originate from.
- SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
- for (unsigned i = 0; i < origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
- continue;
- Location origLoc = block->getArgument(i).getLoc();
- for (unsigned j = 0; j < inputMap->size; ++j)
- newLocs[inputMap->inputNo + j] = origLoc;
- }
-
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
-
- // Remap each of the original arguments as determined by the signature
- // conversion.
- ConvertedBlockInfo info(block, converter);
- info.argInfo.resize(origArgCount);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
- BlockArgument origArg = block->getArgument(i);
-
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
- // FIXME: We simply pass through the replacement argument if there wasn't a
- // converter, which isn't great as it allows implicit type conversions to
- // appear. We should properly restructure this code to handle cases where a
- // converter isn't provided and also to properly handle the case where an
- // argument materialization is actually a temporary source materialization
- // (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
- }
-
- mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
- info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
- }
-
- conversionInfo.insert({newBlock, std::move(info)});
- return newBlock;
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
@@ -694,6 +330,12 @@ class IRRewrite {
/// Commit the rewrite.
virtual void commit() {}
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op);
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block);
+
Kind getKind() const { return kind; }
static bool classof(const IRRewrite *rewrite) { return true; }
@@ -744,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
};
@@ -881,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
// Merge back the block that was split out.
originalBlock->getOperations().splice(originalBlock->end(),
block->getOperations());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
private:
@@ -890,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
Block *originalBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+ /// The start index of in the new argument list that contains arguments that
+ /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- // TODO: Block type conversions are currently committed in
- // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+ void commit() override;
+
void rollback() override;
+
+private:
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81756