[Mlir-commits] [mlir] [mlir][Transforms][NFC] Turn block type conversion into `IRRewrite` (PR #81756)
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 22 00:57:41 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81756
From 04eb7bdd6cc857119f82e0fee253e464b589b725 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 21 Feb 2024 16:48:48 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn block type conversion into
`IRRewrite`
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
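To make the commit/rollback contract concrete, here is a minimal, self-contained C++ sketch of such a rewrite log. It is not MLIR code: `Rewrite`, `AppendRewrite` and `RewriteLog` are hypothetical names, and a string stands in for the IR. Each rewrite is applied eagerly; on success every rewrite is committed in recording order, and on failure the rewrites recorded after a checkpoint are rolled back newest first.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified model of an "IR rewrite"; not MLIR's IRRewrite.
struct Rewrite {
  virtual ~Rewrite() = default;
  virtual void commit() = 0;   // Finalize the change (conversion succeeded).
  virtual void rollback() = 0; // Undo the change (conversion failed).
};

// Appends a suffix to a string immediately and remembers how to undo it.
struct AppendRewrite : Rewrite {
  AppendRewrite(std::string &target, std::string suffix,
                std::vector<std::string> &trace)
      : target(target), suffix(std::move(suffix)), trace(trace) {
    this->target += this->suffix; // Applied eagerly, like a pattern would.
  }
  void commit() override { trace.push_back("commit " + suffix); }
  void rollback() override {
    target.erase(target.size() - suffix.size());
    trace.push_back("rollback " + suffix);
  }
  std::string &target;
  std::string suffix;
  std::vector<std::string> &trace;
};

struct RewriteLog {
  std::vector<std::unique_ptr<Rewrite>> rewrites;

  // On success: commit every rewrite in the order it was recorded.
  void commitAll() {
    for (auto &r : rewrites)
      r->commit();
    rewrites.clear();
  }

  // On failure: undo every rewrite recorded after `numRewrites`, newest first.
  void rollbackTo(std::size_t numRewrites) {
    while (rewrites.size() > numRewrites) {
      rewrites.back()->rollback();
      rewrites.pop_back();
    }
  }
};
```

Rolling back in reverse order matters because a later rewrite may depend on state produced by an earlier one.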
Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent; block type conversions were instead committed separately in `ArgConverter::applyRewrites`.
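A toy model of a block type conversion that has both halves follows. It is hypothetical, with `std::list` standing in for MLIR's block and region lists, and argument remapping is omitted. The conversion moves the operations into a new block that takes the original block's place and keeps the original block detached but alive. `commit` then only has to delete the detached block, while `rollback` moves the operations back, reinserts the original block and deletes the new one.

```cpp
#include <algorithm>
#include <cassert>
#include <list>
#include <string>
#include <utility>
#include <vector>

// Toy IR (hypothetical): a block has typed arguments and a list of op names.
struct Block {
  std::vector<std::string> argTypes;
  std::list<std::string> ops;
};
using Region = std::list<Block *>;

// Simplified analogue of a block type conversion supporting both commit()
// and rollback(). This is not MLIR's implementation.
class BlockTypeConversion {
public:
  BlockTypeConversion(Region &region, Block *block,
                      std::vector<std::string> newTypes)
      : region(region), origBlock(block),
        newBlock(new Block{std::move(newTypes), {}}) {
    // Move the ops into the new block and put it where the old block was.
    newBlock->ops.splice(newBlock->ops.end(), origBlock->ops);
    auto it = std::find(region.begin(), region.end(), origBlock);
    assert(it != region.end() && "block must be in the region");
    *it = newBlock; // The original block is now detached but still alive.
  }

  Block *getNewBlock() const { return newBlock; }

  // Success: the detached original block is no longer needed.
  void commit() {
    delete origBlock;
    origBlock = nullptr;
  }

  // Failure: move the ops back, reinsert the original block, drop the new one.
  void rollback() {
    origBlock->ops.splice(origBlock->ops.end(), newBlock->ops);
    auto it = std::find(region.begin(), region.end(), newBlock);
    assert(it != region.end() && "new block must be in the region");
    *it = origBlock;
    delete newBlock;
    newBlock = nullptr;
  }

private:
  Region &region;
  Block *origBlock;
  Block *newBlock;
};
```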
Overview of changes:
* Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion.
* No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, block type conversions are now committed via `BlockTypeConversionRewrite::commit`.
* Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be impractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, ops/blocks that could otherwise be freed twice must be erased through a new `SingleEraseRewriter` (owned by the `ConversionPatternRewriterImpl`). This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed.
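The erase-once idea behind `SingleEraseRewriter` can be sketched without MLIR as a set of already-freed pointers that is consulted before every deletion. `SingleEraser` and `Node` below are hypothetical; the real rewriter records erased ops and blocks into a `SetVector<void *>` through its `notifyOperationErased`/`notifyBlockErased` listener callbacks.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>

// Hypothetical stand-in for an IR object that must be freed exactly once.
struct Node {
  static int liveCount;
  Node() { ++liveCount; }
  ~Node() { --liveCount; }
};
int Node::liveCount = 0;

// Simplified analogue of SingleEraseRewriter: remember every pointer that has
// been erased and turn repeated erase requests into no-ops.
class SingleEraser {
public:
  void erase(Node *node) {
    if (!erased.insert(node).second)
      return; // Already erased: do not free it a second time.
    delete node;
  }
  bool wasErased(const void *ptr) const { return erased.count(ptr) != 0; }
  std::size_t numErased() const { return erased.size(); }

private:
  std::unordered_set<const void *> erased;
};
```

As the comment on the real class notes, this only works if no new IR is created between erase calls; otherwise a new object could be allocated at a recorded address and then be wrongly skipped.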
---
.../Transforms/Utils/DialectConversion.cpp | 794 ++++++++----------
1 file changed, 364 insertions(+), 430 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index afdd31a748c8c4..db41b9f19e7e8d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations)
+ unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations) {}
+ numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of erased operations/blocks.
+ unsigned numErased;
};
//===----------------------------------------------------------------------===//
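The new `numErased` field follows the same pattern as the other `RewriterState` counters: a checkpoint stores only sizes, so resetting to it amounts to truncating each list. The sketch below models that with hypothetical names (`Checkpoint`, `Tracker`). The body of `resetState` is cut off in this excerpt, so trimming the erased set on reset is an assumption about how `numErased` is used, not a quote of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical checkpoint in the spirit of RewriterState: only sizes are
// stored, because everything recorded later is appended to the lists.
struct Checkpoint {
  std::size_t numRewrites;
  std::size_t numErased;
};

struct Tracker {
  std::vector<int> rewrites;        // stand-in for the list of IR rewrites
  std::vector<const void *> erased; // stand-in for the ordered erased set

  Checkpoint getCurrentState() const {
    return {rewrites.size(), erased.size()};
  }

  // Forget everything recorded after `state` (assumed to be an earlier
  // checkpoint of this tracker, so the sizes can only shrink).
  void resetState(Checkpoint state) {
    rewrites.resize(state.numRewrites);
    erased.resize(state.numErased);
  }
};
```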
@@ -292,370 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
outputType, outputType, converter, unresolvedMaterializations);
}
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
- ArgConverter(
- PatternRewriter &rewriter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
- : rewriter(rewriter),
- unresolvedMaterializations(unresolvedMaterializations) {}
-
- /// This structure contains the information pertaining to an argument that has
- /// been converted.
- struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
- };
-
- /// This structure contains information pertaining to a block that has had its
- /// signature converted.
- struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
- : origBlock(origBlock), converter(converter) {}
-
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
-
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
- /// The type converter used to convert the arguments.
- const TypeConverter *converter;
- };
-
- //===--------------------------------------------------------------------===//
- // Rewrite Application
- //===--------------------------------------------------------------------===//
-
- /// Erase any rewrites registered for the blocks within the given operation
- /// which is about to be removed. This merely drops the rewrites without
- /// undoing them.
- void notifyOpRemoved(Operation *op);
-
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method replaces the new block with the original, reverting the IR to
- /// its original state.
- void discardRewrites(Block *block);
-
- /// Fully replace uses of the old arguments with the new.
- void applyRewrites(ConversionValueMapping &mapping);
-
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(ConversionValueMapping &mapping,
- OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser);
-
- //===--------------------------------------------------------------------===//
- // Conversion
- //===--------------------------------------------------------------------===//
-
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *>
- convertSignature(Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned. If no conversions were
- /// necessary, e.g. if the block has no arguments, `block` is returned.
- /// `converter` is used to generate any necessary cast operations that
- /// translate between the origin argument types and those specified in the
- /// signature conversion.
- Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// A collection of blocks that have had their arguments converted. This is a
- /// map from the new replacement block, back to the original block.
- llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
- /// The pattern rewriter to use when materializing conversions.
- PatternRewriter &rewriter;
-
- /// An ordered set of unresolved materializations during conversion.
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
- if (conversionInfo.empty())
- return;
-
- for (Region &region : op->getRegions()) {
- for (Block &block : region) {
- // Drop any rewrites from within.
- for (Operation &nestedOp : block)
- if (nestedOp.getNumRegions())
- notifyOpRemoved(&nestedOp);
-
- // Check if this block was converted.
- auto *it = conversionInfo.find(&block);
- if (it == conversionInfo.end())
- continue;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument arg : origBlock->getArguments())
- arg.dropAllUses();
- conversionInfo.erase(it);
- }
- }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
- auto *it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
- Block *origBlock = it->second.origBlock;
-
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
- block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- block->erase();
-
- conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
- for (auto &info : conversionInfo) {
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument origArg = origBlock->getArgument(i);
-
- // Handle the case of a 1->0 value mapping.
- if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(castValue, origArg.getType()));
- }
- }
-
- delete origBlock;
- blockInfo.origBlock = nullptr;
- }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
- ConversionValueMapping &mapping, OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser) {
- for (auto &info : conversionInfo) {
- Block *newBlock = info.first;
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- BlockArgument origArg = origBlock->getArgument(i);
- if (mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriter.setInsertionPointToStart(newBlock);
- else
- rewriter.setInsertionPointAfterValue(replacementValue);
- Value newArg;
- if (blockInfo.converter) {
- newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #" << i
- << " that remained live after conversion, type was "
- << origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- mapping.map(origArg, newArg);
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- assert(block->getParent() && "cannot convert signature of detached block");
-
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping,
- argReplacements);
- return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // If no arguments are being changed or added, there is nothing to do.
- unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedTypes();
- if (llvm::equal(block->getArgumentTypes(), convertedTypes))
- return block;
-
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = block->splitBlock(block->begin());
- block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
-
- // Map all new arguments to the location of the argument they originate from.
- SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
- for (unsigned i = 0; i < origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
- continue;
- Location origLoc = block->getArgument(i).getLoc();
- for (unsigned j = 0; j < inputMap->size; ++j)
- newLocs[inputMap->inputNo + j] = origLoc;
- }
-
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
-
- // Remap each of the original arguments as determined by the signature
- // conversion.
- ConvertedBlockInfo info(block, converter);
- info.argInfo.resize(origArgCount);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
- BlockArgument origArg = block->getArgument(i);
-
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
- // FIXME: We simply pass through the replacement argument if there wasn't a
- // converter, which isn't great as it allows implicit type conversions to
- // appear. We should properly restructure this code to handle cases where a
- // converter isn't provided and also to properly handle the case where an
- // argument materialization is actually a temporary source materialization
- // (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
- }
-
- mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
- info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
- }
-
- conversionInfo.insert({newBlock, std::move(info)});
- return newBlock;
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
@@ -702,6 +342,12 @@ class IRRewrite {
IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
: kind(kind), rewriterImpl(rewriterImpl) {}
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op);
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block);
+
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -744,8 +390,7 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
};
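`CreateBlockRewrite::rollback` above removes the operations from the block without destroying them before erasing the block, so that operations whose lifetime is managed elsewhere are not freed together with it. The same detach-then-erase idea is behind `detachNestedAndErase`. A minimal sketch with a hypothetical toy IR (`Op`, `Block`, `detachOpsAndErase`), where a block frees any operations it still holds when it is destroyed:

```cpp
#include <cassert>
#include <list>

// Hypothetical toy IR object that counts live instances.
struct Op {
  static int live;
  Op() { ++live; }
  ~Op() { --live; }
};
int Op::live = 0;

// A block frees the ops it still holds when it is destroyed.
struct Block {
  Block() = default;
  Block(const Block &) = delete;
  Block &operator=(const Block &) = delete;
  ~Block() {
    for (Op *op : ops)
      delete op;
  }
  std::list<Op *> ops;
};

// Remove the ops from `block` without freeing them, then destroy the block.
// The detached ops are returned; whoever owns them frees them exactly once.
std::list<Op *> detachOpsAndErase(Block *block) {
  std::list<Op *> detached;
  detached.splice(detached.end(), block->ops);
  delete block; // Frees only the now-empty block.
  return detached;
}
```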
@@ -881,8 +526,7 @@ class SplitBlockRewrite : public BlockRewrite {
// Merge back the block that was split out.
originalBlock->getOperations().splice(originalBlock->end(),
block->getOperations());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
private:
@@ -890,20 +534,59 @@ class SplitBlockRewrite : public BlockRewrite {
Block *originalBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+ /// The start index in the new argument list of the arguments that
+ /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This is only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- // TODO: Block type conversions are currently committed in
- // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+ void commit() override;
+
void rollback() override;
+
+private:
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter *converter;
};
/// An operation rewrite.
@@ -949,8 +632,8 @@ class MoveOperationRewrite : public OperationRewrite {
// The block in which this operation was previously contained.
Block *block;
- // The original successor of this operation before it was moved. "nullptr" if
- // this operation was the only operation in the region.
+ // The original successor of this operation before it was moved. "nullptr"
+ // if this operation was the only operation in the region.
Operation *insertBeforeOp;
};
@@ -1027,6 +710,26 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
+/// Find the single rewrite object of the specified type and block among the
+/// given rewrites. In debug mode, asserts that there is mo more than one such
+/// object. Return "nullptr" if no object was found.
+template <typename RewriteTy, typename R>
+static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
+ RewriteTy *result = nullptr;
+ for (auto &rewrite : rewrites) {
+ auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+ if (rewriteTy && rewriteTy->getBlock() == block) {
+#ifndef NDEBUG
+ assert(!result && "expected single matching rewrite");
+ result = rewriteTy;
+#else
+ return rewriteTy;
+#endif // NDEBUG
+ }
+ }
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
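`findSingleRewrite` behaves differently depending on `NDEBUG`: debug builds scan the whole list so that a second match trips the assertion, while release builds return the first match. The sketch below reproduces that control flow with hypothetical types, using an `int` in place of `Block *` and standard `dynamic_cast` in place of LLVM's `dyn_cast`.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical rewrite hierarchy standing in for the IRRewrite subclasses.
struct Rewrite {
  virtual ~Rewrite() = default;
};
struct BlockRewrite : Rewrite {
  explicit BlockRewrite(int block) : block(block) {}
  int getBlock() const { return block; }
  int block;
};
struct OtherRewrite : Rewrite {};

template <typename RewriteTy, typename R>
RewriteTy *findSingleRewrite(R &&rewrites, int block) {
  RewriteTy *result = nullptr;
  for (auto &rewrite : rewrites) {
    auto *match = dynamic_cast<RewriteTy *>(rewrite.get());
    if (match && match->getBlock() == block) {
#ifndef NDEBUG
      // Debug builds keep scanning so that a second match trips the assert.
      assert(!result && "expected single matching rewrite");
      result = match;
#else
      return match; // Release builds stop at the first match.
#endif
    }
  }
  return result;
}
```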
@@ -1034,7 +737,7 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
- : argConverter(rewriter, unresolvedMaterializations),
+ : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
notifyCallback(nullptr) {}
/// Cleanup and destroy any generated rewrite operations. This method is
@@ -1084,15 +787,33 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// removes them from being considered for legalization.
void markNestedOpsIgnored(Operation *op);
+ /// Detach any operations nested in the given operation from their parent
+ /// blocks, and erase the given operation. This can be used when the nested
+ /// operations are scheduled for erasure themselves, so deleting the regions
+ /// of the given operation together with their content would result in
+ /// double-free. This happens, for example, when rolling back op creation in
+ /// the reverse order and if the nested ops were created before the parent op.
+ /// This function does not need to collect nested ops recursively because it
+ /// is expected to also be called for each nested op when it is about to be
+ /// deleted.
+ void detachNestedAndErase(Operation *op);
+
//===--------------------------------------------------------------------===//
// Type Conversion
//===--------------------------------------------------------------------===//
- /// Convert the signature of the given block.
+ /// Attempt to convert the signature of the given block. If successful, a new
+ /// block is returned containing the new arguments. Returns `block` if it did
+ /// not require conversion.
FailureOr<Block *> convertBlockSignature(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
+ /// Convert the types of non-entry block arguments within the given region.
+ LogicalResult convertNonEntryRegionTypes(
+ Region *region, const TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+
/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
@@ -1105,10 +826,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
- /// Convert the types of non-entry block arguments within the given region.
- LogicalResult convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+ /// Apply the given signature conversion on the given block. The new block
+ /// containing the updated signature is returned. If no conversions were
+ /// necessary, e.g. if the block has no arguments, `block` is returned.
+ /// `converter` is used to generate any necessary cast operations that
+ /// translate between the origin argument types and those specified in the
+ /// signature conversion.
+ Block *applySignatureConversion(
+ Block *block, const TypeConverter *converter,
+ TypeConverter::SignatureConversion &signatureConversion);
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1140,17 +866,54 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
+ //===--------------------------------------------------------------------===//
+ // IR Erasure
+ //===--------------------------------------------------------------------===//
+
+ /// A rewriter that keeps track of erased ops and blocks. It ensures that no
+ /// operation or block is erased multiple times. This rewriter assumes that
+ /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
+ struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
+ public:
+ SingleEraseRewriter(MLIRContext *context)
+ : RewriterBase(context, /*listener=*/this) {}
+
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op) override {
+ if (erased.contains(op))
+ return;
+ op->dropAllUses();
+ RewriterBase::eraseOp(op);
+ }
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block) override {
+ if (erased.contains(block))
+ return;
+ block->dropAllDefinedValueUses();
+ RewriterBase::eraseBlock(block);
+ }
+
+ void notifyOperationErased(Operation *op) override { erased.insert(op); }
+ void notifyBlockErased(Block *block) override { erased.insert(block); }
+
+ /// Pointers to all erased operations and blocks.
+ SetVector<void *> erased;
+ };
+
//===--------------------------------------------------------------------===//
// State
//===--------------------------------------------------------------------===//
+ PatternRewriter &rewriter;
+
+ /// This rewriter must be used for erasing ops/blocks.
+ SingleEraseRewriter eraseRewriter;
+
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Utility used to convert block arguments.
- ArgConverter argConverter;
-
/// Ordered vector of all of the newly created operations during conversion.
SmallVector<Operation *> createdOps;
@@ -1207,20 +970,100 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
} // namespace detail
} // namespace mlir
+void IRRewrite::eraseOp(Operation *op) {
+ rewriterImpl.eraseRewriter.eraseOp(op);
+}
+
+void IRRewrite::eraseBlock(Block *block) {
+ rewriterImpl.eraseRewriter.eraseBlock(block);
+}
+
+void BlockTypeConversionRewrite::commit() {
+ // Process the remapping for each of the original arguments.
+ for (auto [origArg, info] :
+ llvm::zip_equal(origBlock->getArguments(), argInfo)) {
+ // Handle the case of a 1->0 value mapping.
+ if (!info) {
+ if (Value newArg =
+ rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ origArg.replaceAllUsesWith(newArg);
+ continue;
+ }
+
+ // Otherwise this is a 1->1+ value mapping.
+ Value castValue = info->castValue;
+ assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
+
+ // If the argument is still used, replace it with the generated cast.
+ if (!origArg.use_empty()) {
+ origArg.replaceAllUsesWith(
+ rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+ }
+ }
+
+ delete origBlock;
+ origBlock = nullptr;
+}
+
void BlockTypeConversionRewrite::rollback() {
- // Undo the type conversion.
- rewriterImpl.argConverter.discardRewrites(block);
-}
-
-/// Detach any operations nested in the given operation from their parent
-/// blocks, and erase the given operation. This can be used when the nested
-/// operations are scheduled for erasure themselves, so deleting the regions of
-/// the given operation together with their content would result in double-free.
-/// This happens, for example, when rolling back op creation in the reverse
-/// order and if the nested ops were created before the parent op. This function
-/// does not need to collect nested ops recursively because it is expected to
-/// also be called for each nested op when it is about to be deleted.
-static void detachNestedAndErase(Operation *op) {
+ // Drop all uses of the new block arguments and replace uses of the new block.
+ for (int i = block->getNumArguments() - 1; i >= 0; --i)
+ block->getArgument(i).dropAllUses();
+ block->replaceAllUsesWith(origBlock);
+
+ // Move the operations back into the original block, move the original block
+ // back into its original location and then delete the new block.
+ origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+ block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
+ eraseBlock(block);
+}
+
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+ function_ref<Operation *(Value)> findLiveUser) {
+ // Process the remapping for each of the original arguments.
+ for (auto it : llvm::enumerate(origBlock->getArguments())) {
+ // If the type of this argument changed and the argument is still live, we
+ // need to materialize a conversion.
+ BlockArgument origArg = it.value();
+ if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ continue;
+ Operation *liveUser = findLiveUser(origArg);
+ if (!liveUser)
+ continue;
+
+ Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+ bool isDroppedArg = replacementValue == origArg;
+ if (isDroppedArg)
+ rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
+ else
+ rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+ Value newArg;
+ if (converter) {
+ newArg = converter->materializeSourceConversion(
+ rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+ isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+ assert((!newArg || newArg.getType() == origArg.getType()) &&
+ "materialization hook did not provide a value of the expected "
+ "type");
+ }
+ if (!newArg) {
+ InFlightDiagnostic diag =
+ emitError(origArg.getLoc())
+ << "failed to materialize conversion for block argument #"
+ << it.index() << " that remained live after conversion, type was "
+ << origArg.getType();
+ if (!isDroppedArg)
+ diag << ", with target type " << replacementValue.getType();
+ diag.attachNote(liveUser->getLoc())
+ << "see existing live user here: " << *liveUser;
+ return failure();
+ }
+ rewriterImpl.mapping.map(origArg, newArg);
+ }
+ return success();
+}
+
+void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
while (!block.getOperations().empty())
@@ -1228,8 +1071,7 @@ static void detachNestedAndErase(Operation *op) {
block.dropAllDefinedValueUses();
}
}
- op->dropAllUses();
- op->erase();
+ eraseRewriter.eraseOp(op);
}
void ConversionPatternRewriterImpl::discardRewrites() {
@@ -1248,11 +1090,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
for (OpResult result : repl.first->getResults())
if (Value newValue = mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
-
- // If this operation defines any regions, drop any pending argument
- // rewrites.
- if (repl.first->getNumRegions())
- argConverter.notifyOpRemoved(repl.first);
}
// Apply all of the requested argument replacements.
@@ -1279,22 +1116,16 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// Drop all of the unresolved materialization operations created during
// conversion.
- for (auto &mat : unresolvedMaterializations) {
- mat.getOp()->dropAllUses();
- mat.getOp()->erase();
- }
+ for (auto &mat : unresolvedMaterializations)
+ eraseRewriter.eraseOp(mat.getOp());
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed. Because we process in reverse order, producers may be deleted
// before their users (a pattern deleting a producer and then the consumer)
// so we first drop all uses explicitly.
- for (auto &repl : llvm::reverse(replacements)) {
- repl.first->dropAllUses();
- repl.first->erase();
- }
-
- argConverter.applyRewrites(mapping);
+ for (auto &repl : llvm::reverse(replacements))
+ eraseRewriter.eraseOp(repl.first);
// Commit all rewrites.
for (auto &rewrite : rewrites)
@@ -1307,7 +1138,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
- rewrites.size(), ignoredOps.size());
+ rewrites.size(), ignoredOps.size(),
+ eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1355,6 +1187,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (!operationsWithChangedResults.empty() &&
operationsWithChangedResults.back() >= state.numReplacements)
operationsWithChangedResults.pop_back();
+
+ while (eraseRewriter.erased.size() != state.numErased)
+ eraseRewriter.erased.pop_back();
}
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1443,18 +1278,18 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
- FailureOr<Block *> result =
- conversion ? argConverter.applySignatureConversion(
- block, converter, *conversion, mapping, argReplacements)
- : argConverter.convertSignature(block, converter, mapping,
- argReplacements);
- if (failed(result))
+ if (conversion)
+ return applySignatureConversion(block, converter, *conversion);
+
+ // If a converter wasn't provided, and the block wasn't already converted,
+ // there is nothing we can do.
+ if (!converter)
return failure();
- if (Block *newBlock = *result) {
- if (newBlock != block)
- appendRewrite<BlockTypeConversionRewrite>(newBlock);
- }
- return result;
+
+ // Try to convert the signature for the block with the provided converter.
+ if (auto conversion = converter->convertBlockSignature(block))
+ return applySignatureConversion(block, converter, *conversion);
+ return failure();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1508,6 +1343,102 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
return success();
}
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
+ Block *block, const TypeConverter *converter,
+ TypeConverter::SignatureConversion &signatureConversion) {
+ // If no arguments are being changed or added, there is nothing to do.
+ unsigned origArgCount = block->getNumArguments();
+ auto convertedTypes = signatureConversion.getConvertedTypes();
+ if (llvm::equal(block->getArgumentTypes(), convertedTypes))
+ return block;
+
+ // Split the block at the beginning to get a new block to use for the updated
+ // signature.
+ Block *newBlock = block->splitBlock(block->begin());
+ block->replaceAllUsesWith(newBlock);
+ // Unlink the block, but do not erase it yet, so that the change can be rolled
+ // back.
+ block->getParent()->getBlocks().remove(block);
+
+ // Map all new arguments to the location of the argument they originate from.
+ SmallVector<Location> newLocs(convertedTypes.size(),
+ rewriter.getUnknownLoc());
+ for (unsigned i = 0; i < origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap || inputMap->replacementValue)
+ continue;
+ Location origLoc = block->getArgument(i).getLoc();
+ for (unsigned j = 0; j < inputMap->size; ++j)
+ newLocs[inputMap->inputNo + j] = origLoc;
+ }
+
+ SmallVector<Value, 4> newArgRange(
+ newBlock->addArguments(convertedTypes, newLocs));
+ ArrayRef<Value> newArgs(newArgRange);
+
+ // Remap each of the original arguments as determined by the signature
+ // conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+ argInfo.resize(origArgCount);
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(newBlock);
+ for (unsigned i = 0; i != origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap)
+ continue;
+ BlockArgument origArg = block->getArgument(i);
+
+ // If inputMap->replacementValue is not nullptr, then the argument is
+ // dropped and a replacement value is provided to be the remappedValue.
+ if (inputMap->replacementValue) {
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, inputMap->replacementValue);
+ argReplacements.push_back(origArg);
+ continue;
+ }
+
+ // Otherwise, this is a 1->1+ mapping.
+ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ Value newArg;
+
+ // If this is a 1->1 mapping and the types of new and replacement arguments
+ // match (i.e. it's an identity map), then the argument is mapped to its
+ // original type.
+ // FIXME: We simply pass through the replacement argument if there wasn't a
+ // converter, which isn't great as it allows implicit type conversions to
+ // appear. We should properly restructure this code to handle cases where a
+ // converter isn't provided and also to properly handle the case where an
+ // argument materialization is actually a temporary source materialization
+ // (e.g. in the case of 1->N).
+ if (replArgs.size() == 1 &&
+ (!converter || replArgs[0].getType() == origArg.getType())) {
+ newArg = replArgs.front();
+ } else {
+ Type origOutputType = origArg.getType();
+
+ // Legalize the argument output type.
+ Type outputType = origOutputType;
+ if (Type legalOutputType = converter->convertType(outputType))
+ outputType = legalOutputType;
+
+ newArg = buildUnresolvedArgumentMaterialization(
+ rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+ converter, unresolvedMaterializations);
+ }
+
+ mapping.map(origArg, newArg);
+ argReplacements.push_back(origArg);
+ argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
+ }
+
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
+ converter);
+ return newBlock;
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2635,8 +2566,11 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
- return rewriterImpl.argConverter.materializeLiveConversions(
- rewriterImpl.mapping, rewriter, findLiveUser);
+ for (auto &r : rewriterImpl.rewrites)
+ if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
+ if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+ return failure();
+ return success();
}
/// Replace the results of a materialization operation with the given values.
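The `eraseRewriter` changes in this patch rely on one idea: erasing the same object twice becomes a no-op. In addition, the list of erased objects can be truncated back to a saved size when state is reset. This matches `getCurrentState` recording `eraseRewriter.erased.size()` and `resetState` popping entries until `state.numErased` is reached. What follows is a minimal, self-contained C++ sketch of that bookkeeping under simplifying assumptions. `Node` and `SingleEraseTracker` are hypothetical names, the sketch does not use the MLIR API, and it only records erasures rather than destroying anything.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an IR object (not an MLIR type).
struct Node {
  int id;
};

// Hypothetical tracker illustrating the "single erase" idea: a repeated
// erase of the same object is ignored, and the erase log can be truncated
// back to an earlier size. This sketch only records erasures; a real
// rewriter would also destroy the object the first time.
class SingleEraseTracker {
public:
  // Returns true if `node` is erased now, false if it was already erased
  // (erasing it again in real IR would be a double free).
  bool erase(Node *node) {
    if (!seen.insert(node).second)
      return false;
    log.push_back(node);
    return true;
  }

  bool isErased(Node *node) const { return seen.count(node) != 0; }

  // Size of the erase log, usable as a snapshot for a later reset.
  std::size_t snapshot() const { return log.size(); }

  // Forget every erasure recorded after the snapshot `numErased`.
  void resetTo(std::size_t numErased) {
    assert(numErased <= log.size() && "cannot reset to a later snapshot");
    while (log.size() != numErased) {
      seen.erase(log.back());
      log.pop_back();
    }
  }

private:
  std::vector<Node *> log;         // Erasure order, for truncation.
  std::unordered_set<Node *> seen; // Fast "already erased?" checks.
};
```

Keeping both a vector and a set is a design choice of this sketch. The vector preserves order so a reset can pop entries back to a snapshot, and the set keeps the "already erased?" check cheap.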
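With this patch, a block type conversion becomes one more entry in the `rewrites` list. Commit, rollback, and the final `materializeLiveConversions` pass therefore all iterate over that list, and the last of these picks out `BlockTypeConversionRewrite` entries with `dyn_cast`. A hedged sketch of that pattern in plain C++ follows. `Rewrite`, `SignatureChange`, `BlockMove`, and `RewriteLog` are hypothetical names. `dynamic_cast` stands in for LLVM-style `dyn_cast`, which relies on `classof` rather than C++ RTTI.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical base class for one reversible IR change.
class Rewrite {
public:
  virtual ~Rewrite() = default;
  virtual void commit() = 0;   // Make the change permanent.
  virtual void rollback() = 0; // Undo the change.
};

// Hypothetical analogue of a block signature change; it also has an extra
// step that only this kind of rewrite supports.
class SignatureChange : public Rewrite {
public:
  explicit SignatureChange(std::vector<std::string> &events) : events(events) {}
  void commit() override { events.push_back("commit signature"); }
  void rollback() override { events.push_back("rollback signature"); }
  bool materialize() {
    events.push_back("materialize");
    return true;
  }

private:
  std::vector<std::string> &events;
};

// Hypothetical analogue of another kind of rewrite, e.g. a block move.
class BlockMove : public Rewrite {
public:
  explicit BlockMove(std::vector<std::string> &events) : events(events) {}
  void commit() override { events.push_back("commit move"); }
  void rollback() override { events.push_back("rollback move"); }

private:
  std::vector<std::string> &events;
};

class RewriteLog {
public:
  void append(std::unique_ptr<Rewrite> rewrite) {
    rewrites.push_back(std::move(rewrite));
  }
  std::size_t size() const { return rewrites.size(); }

  // On success: commit every rewrite in the order it was recorded.
  void commitAll() {
    for (auto &rewrite : rewrites)
      rewrite->commit();
    rewrites.clear();
  }

  // On failure: roll back newest-first until `numToKeep` rewrites remain.
  void undo(std::size_t numToKeep) {
    assert(numToKeep <= rewrites.size() && "cannot undo to a later state");
    while (rewrites.size() > numToKeep) {
      rewrites.back()->rollback();
      rewrites.pop_back();
    }
  }

  // Run the extra step on signature changes only; stop at the first failure.
  bool materializeAll() {
    for (auto &rewrite : rewrites)
      if (auto *sig = dynamic_cast<SignatureChange *>(rewrite.get()))
        if (!sig->materialize())
          return false;
    return true;
  }

private:
  std::vector<std::unique_ptr<Rewrite>> rewrites;
};
```

Rolling back newest-first is the usual choice, because a later rewrite may depend on IR produced by an earlier one.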
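The new `applySignatureConversion` handles argument locations in two steps. First, every new block argument starts with an unknown location. Then each original argument that maps onto a range of new arguments (`inputNo`, `size`) copies its location into that range, while dropped arguments that carry a `replacementValue` are skipped. The sketch below reproduces only that location step with hypothetical, simplified types (`InputMapping`, `propagateLocations`). It uses strings in place of MLIR `Location` objects and a boolean in place of the replacement value.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified mapping for one original block argument.
struct InputMapping {
  unsigned inputNo = 0;        // Index of the first new argument it maps to.
  unsigned size = 0;           // Number of new arguments (0 when dropped).
  bool hasReplacement = false; // Dropped and replaced by an existing value.
};

// Returns one location per new argument: new arguments that come from an
// original argument inherit its location, all others get `unknownLoc`.
std::vector<std::string> propagateLocations(
    const std::vector<std::string> &origLocs,
    const std::vector<std::optional<InputMapping>> &mappings,
    std::size_t numNewArgs, const std::string &unknownLoc) {
  assert(mappings.size() == origLocs.size() && "one mapping per argument");
  std::vector<std::string> newLocs(numNewArgs, unknownLoc);
  for (std::size_t i = 0; i < origLocs.size(); ++i) {
    const std::optional<InputMapping> &inputMap = mappings[i];
    // Skip arguments that are removed outright or replaced by a value.
    if (!inputMap || inputMap->hasReplacement)
      continue;
    assert(inputMap->inputNo + inputMap->size <= numNewArgs);
    for (unsigned j = 0; j < inputMap->size; ++j)
      newLocs[inputMap->inputNo + j] = origLocs[i];
  }
  return newLocs;
}
```

Consider three original arguments where the second expands 1->2 and the third is dropped. A fourth, newly appended argument has no origin, so it keeps the unknown location.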