[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (PR #148415)
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 13 01:35:36 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/148415
Store metadata about unresolved materializations in a separate data structure. This is in preparation for the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of `IRRewrite` objects. Therefore, metadata about unresolved materializations can no longer be retrieved from `UnresolvedMaterializationRewrite` objects.
This commit also removes a pointer indirection, which may slightly improve the performance of the existing driver.
From 7cedbb943b125020888f6833a3c344c350f47efb Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 13 Jul 2025 08:31:32 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Dialect Conversion: Store
materialization metadata separately
---
.../Transforms/Utils/DialectConversion.cpp | 89 +++++++++----------
1 file changed, 43 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..42abc152981e6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,38 +789,22 @@ enum MaterializationKind {
Source
};
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
public:
- UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- UnrealizedConversionCastOp op,
- const TypeConverter *converter,
- MaterializationKind kind, Type originalType,
- ValueVector mappedValues);
-
- static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::UnresolvedMaterialization;
- }
-
- void rollback() override;
-
- UnrealizedConversionCastOp getOperation() const {
- return cast<UnrealizedConversionCastOp>(op);
- }
+ UnresolvedMaterializationInfo() = default;
+ UnresolvedMaterializationInfo(const TypeConverter *converter,
+ MaterializationKind kind, Type originalType)
+ : converterAndKind(converter, kind), originalType(originalType) {}
- /// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
return converterAndKind.getPointer();
}
- /// Return the kind of this materialization.
MaterializationKind getMaterializationKind() const {
return converterAndKind.getInt();
}
- /// Return the original type of the SSA value.
Type getOriginalType() const { return originalType; }
private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op,
+ ValueVector mappedValues)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ mappedValues(std::move(mappedValues)) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ void rollback() override;
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+private:
/// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
- /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
- /// to the corresponding rewrite objects.
- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ /// A mapping for looking up metadata of unresolved materializations.
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind, Type originalType,
- ValueVector mappedValues)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), originalType(originalType),
- mappedValues(std::move(mappedValues)) {
- assert((!originalType || kind == MaterializationKind::Target) &&
- "original type is valid only for target materializations");
- rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
void UnresolvedMaterializationRewrite::rollback() {
if (!mappedValues.empty())
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
- appendRewrite<UnresolvedMaterializationRewrite>(
- convertOp, converter, kind, originalType, std::move(valuesToMap));
+ unresolvedMaterializations[convertOp] =
+ UnresolvedMaterializationInfo(converter, kind, originalType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
return convertOp.getResults();
}
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
static LogicalResult
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
- UnresolvedMaterializationRewrite *rewrite) {
- UnrealizedConversionCastOp op = rewrite->getOperation();
+ UnrealizedConversionCastOp op,
+ const UnresolvedMaterializationInfo &info) {
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
// Try to materialize the conversion.
- if (const TypeConverter *converter = rewrite->getConverter()) {
+ if (const TypeConverter *converter = info.getConverter()) {
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
- switch (rewrite->getMaterializationKind()) {
+ switch (info.getMaterializationKind()) {
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
- rewrite->getOriginalType());
+ info.getOriginalType());
break;
case MaterializationKind::Source:
assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations)
allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+ if (failed(
+ legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
return failure();
}
}
More information about the Mlir-commits
mailing list