[llvm-branch-commits] [mlir] [mlir][Transforms][NFC][WIP] Turn unresolved materializations into `IRRewrite`s (PR #81761)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 19 02:33:20 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
This commit turns the creation of unresolved materializations (`unrealized_conversion_cast`) into `IRRewrite` objects. After this commit, all steps in `applyRewrites` and `discardRewrites` are calls to `IRRewrite::commit` and `IRRewrite::rollback`.
---
Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff
1 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+176-195)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
- unsigned numIgnoredOperations, unsigned numErased)
- : numUnresolvedMaterializations(numUnresolvedMaterializations),
- numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
+ : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}
- /// The current number of unresolved materializations.
- unsigned numUnresolvedMaterializations;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -171,109 +167,10 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
- /// The type of materialization.
- enum Kind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
- Argument,
-
- /// This materialization materializes a conversion from an illegal type to a
- /// legal one.
- Target
- };
-
- UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
- const TypeConverter *converter = nullptr,
- Kind kind = Target, Type origOutputType = nullptr)
- : op(op), converterAndKind(converter, kind),
- origOutputType(origOutputType) {}
-
- /// Return the temporary conversion operation inserted for this
- /// materialization.
- UnrealizedConversionCastOp getOp() const { return op; }
-
- /// Return the type converter of this materialization (which may be null).
- const TypeConverter *getConverter() const {
- return converterAndKind.getPointer();
- }
-
- /// Return the kind of this materialization.
- Kind getKind() const { return converterAndKind.getInt(); }
-
- /// Set the kind of this materialization.
- void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
-
-private:
- /// The unresolved materialization operation created during conversion.
- UnrealizedConversionCastOp op;
-
- /// The corresponding type converter to use when resolving this
- /// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
- /// The original output type. This is only used for argument conversions.
- Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
- UnresolvedMaterialization::Kind kind, Block *insertBlock,
- Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
- Type origOutputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- // Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
- return inputs.front();
-
- // Create an unresolved materialization. We use a new OpBuilder to avoid
- // tracking the materialization like we do for other operations.
- OpBuilder builder(insertBlock, insertPt);
- auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.emplace_back(convertOp, converter, kind,
- origOutputType);
- return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
- rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
- converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- Block *insertBlock = input.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(input))
- insertPt = ++inputRes.getOwner()->getIterator();
-
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
- outputType, outputType, converter, unresolvedMaterializations);
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
-namespace {
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -295,7 +192,8 @@ class IRRewrite {
MoveOperation,
ModifyOperation,
ReplaceOperation,
- CreateOperation
+ CreateOperation,
+ UnresolvedMaterialization
};
virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::CreateOperation;
+ rewrite->getKind() <= Kind::UnresolvedMaterialization;
}
protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
void rollback() override;
};
+
+/// The type of materialization.
+enum MaterializationKind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+ MaterializationKind kind = MaterializationKind::Target,
+ Type origOutputType = nullptr)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+ void rollback() override;
+
+ void cleanup() override;
+
+ /// Return the type converter of this materialization (which may be null).
+ const TypeConverter *getConverter() const {
+ return converterAndKind.getPointer();
+ }
+
+ /// Return the kind of this materialization.
+ MaterializationKind getMaterializationKind() const {
+ return converterAndKind.getInt();
+ }
+
+ /// Set the kind of this materialization.
+ void setMaterializationKind(MaterializationKind kind) {
+ converterAndKind.setInt(kind);
+ }
+
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
+
+private:
+ /// The corresponding type converter to use when resolving this
+ /// materialization, and the kind of this materialization.
+ llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ converterAndKind;
+
+ /// The original output type. This is only used for argument conversions.
+ Type origOutputType;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
notifyCallback(nullptr) {}
- /// Cleanup and destroy any generated rewrite operations. This method is
- /// invoked when the conversion process fails.
- void discardRewrites();
-
- /// Apply all requested operation rewrites. This method is invoked when the
- /// conversion process succeeds.
- void applyRewrites();
-
//===--------------------------------------------------------------------===//
// State Management
//===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return the current state of the rewriter.
RewriterState getCurrentState();
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// removes them from being considered for legalization.
void markNestedOpsIgnored(Operation *op);
- /// Detach any operations nested in the given operation from their parent
- /// blocks, and erase the given operation. This can be used when the nested
- /// operations are scheduled for erasure themselves, so deleting the regions
- /// of the given operation together with their content would result in
- /// double-free. This happens, for example, when rolling back op creation in
- /// the reverse order and if the nested ops were created before the parent op.
- /// This function does not need to collect nested ops recursively because it
- /// is expected to also be called for each nested op when it is about to be
- /// deleted.
- void detachNestedAndErase(Operation *op);
-
//===--------------------------------------------------------------------===//
// Type Conversion
//===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
+ //===--------------------------------------------------------------------===//
+ // Materializations
+ //===--------------------------------------------------------------------===//
+ /// Build an unresolved materialization operation given an output type and set
+ /// of input operands.
+ Value buildUnresolvedMaterialization(MaterializationKind kind,
+ Block *insertBlock,
+ Block::iterator insertPt, Location loc,
+ ValueRange inputs, Type outputType,
+ Type origOutputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+ Location loc, ValueRange inputs,
+ Type origOutputType,
+ Type outputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+ Type outputType,
+ const TypeConverter *converter);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Ordered vector of all unresolved type conversion materializations during
- /// conversion.
- SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
eraseOp(op);
}
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
- // if (erasedIR.erasedOps.contains(op)) return;
-
- for (Region &region : op->getRegions()) {
- for (Block &block : region.getBlocks()) {
- while (!block.getOperations().empty())
- block.getOperations().remove(block.getOperations().begin());
- block.dropAllDefinedValueUses();
- }
+void UnresolvedMaterializationRewrite::rollback() {
+ if (getMaterializationKind() == MaterializationKind::Target) {
+ for (Value input : op->getOperands())
+ rewriterImpl.mapping.erase(input);
}
- eraseRewriter.eraseOp(op);
+ eraseOp(op);
}
-void ConversionPatternRewriterImpl::discardRewrites() {
- undoRewrites();
-
- // Remove any newly created ops.
- for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
- detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrite->commit();
for (auto &rewrite : rewrites)
rewrite->cleanup();
-
- // Drop all of the unresolved materialization operations created during
- // conversion.
- for (auto &mat : unresolvedMaterializations)
- eraseRewriter.eraseOp(mat.getOp());
}
//===----------------------------------------------------------------------===//
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
- ignoredOps.size(), eraseRewriter.erased.size());
+ return RewriterState(rewrites.size(), ignoredOps.size(),
+ eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Pop all of the newly inserted materializations.
- while (unresolvedMaterializations.size() !=
- state.numUnresolvedMaterializations) {
- UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
- UnrealizedConversionCastOp op = mat.getOp();
-
- // If this was a target materialization, drop the mapping that was inserted.
- if (mat.getKind() == UnresolvedMaterialization::Target) {
- for (Value input : op->getOperands())
- mapping.erase(input);
- }
- detachNestedAndErase(op);
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
- operandLoc, newOperand, desiredType, currentTypeConverter,
- unresolvedMaterializations);
+ operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newArg = buildUnresolvedArgumentMaterialization(
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
+ converter);
}
mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
return newBlock;
}
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+ Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ const TypeConverter *converter) {
+ // Avoid materializing an unnecessary cast.
+ if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ return inputs.front();
+
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(insertBlock, insertPt);
+ auto convertOp =
+ builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ origOutputType);
+ return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+ PatternRewriter &rewriter, Location loc, ValueRange inputs,
+ Type origOutputType, Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(
+ MaterializationKind::Argument, rewriter.getInsertionBlock(),
+ rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+ converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType,
+ const TypeConverter *converter) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(MaterializationKind::Target,
+ insertBlock, insertPt, loc, input,
+ outputType, outputType, converter);
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// Now that all of the operations have been converted, finalize the conversion
// process to ensure any lingering conversion artifacts are cleaned up and
// legalized.
if (failed(finalize(rewriter)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// After a successful conversion, apply rewrites if this is not an analysis
// conversion.
if (mode == OpConversionMode::Analysis) {
- rewriterImpl.discardRewrites();
+ rewriterImpl.undoRewrites();
} else {
rewriterImpl.applyRewrites();
}
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
/// Compute all of the unresolved materializations that will persist beyond the
/// conversion process, and require inserting a proper user materialization for.
static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81761
More information about the llvm-branch-commits
mailing list