[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (PR #148415)

Matthias Springer llvmlistbot at llvm.org
Sun Jul 13 01:42:26 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/148415

From 2ce59d77287386405469d70be92fe00433b2fc1d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 13 Jul 2025 08:31:32 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Dialect Conversion: Store
 materialization metadata separately

---
 .../Transforms/Utils/DialectConversion.cpp    | 86 +++++++++----------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..07e90705bd79b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,26 +789,13 @@ enum MaterializationKind {
   Source
 };
 
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
 public:
-  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                                   UnrealizedConversionCastOp op,
-                                   const TypeConverter *converter,
-                                   MaterializationKind kind, Type originalType,
-                                   ValueVector mappedValues);
-
-  static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::UnresolvedMaterialization;
-  }
-
-  void rollback() override;
-
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
+  UnresolvedMaterializationInfo() = default;
+  UnresolvedMaterializationInfo(const TypeConverter *converter,
+                                MaterializationKind kind, Type originalType)
+      : converterAndKind(converter, kind), originalType(originalType) {}
 
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
@@ -832,7 +819,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original type of the SSA value. Only used for target
   /// materializations.
   Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   ValueVector mappedValues)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        mappedValues(std::move(mappedValues)) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  void rollback() override;
 
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+private:
   /// The values in the conversion value mapping that are being replaced by the
   /// results of this unresolved materialization.
   ValueVector mappedValues;
@@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// by the current pattern.
   SetVector<Block *> patternInsertedBlocks;
 
-  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
-  /// to the corresponding rewrite objects.
-  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  /// A mapping for looking up metadata of unresolved materializations.
+  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
-    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    ValueVector mappedValues)
-    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind), originalType(originalType),
-      mappedValues(std::move(mappedValues)) {
-  assert((!originalType || kind == MaterializationKind::Target) &&
-         "original type is valid only for target materializations");
-  rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
 void UnresolvedMaterializationRewrite::rollback() {
   if (!mappedValues.empty())
     rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
-  appendRewrite<UnresolvedMaterializationRewrite>(
-      convertOp, converter, kind, originalType, std::move(valuesToMap));
+  unresolvedMaterializations[convertOp] =
+      UnresolvedMaterializationInfo(converter, kind, originalType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+                                                  std::move(valuesToMap));
   return convertOp.getResults();
 }
 
@@ -2679,21 +2678,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
 
 static LogicalResult
 legalizeUnresolvedMaterialization(RewriterBase &rewriter,
-                                  UnresolvedMaterializationRewrite *rewrite) {
-  UnrealizedConversionCastOp op = rewrite->getOperation();
+                                  UnrealizedConversionCastOp op,
+                                  const UnresolvedMaterializationInfo &info) {
   assert(!op.use_empty() &&
          "expected that dead materializations have already been DCE'd");
   Operation::operand_range inputOperands = op.getOperands();
 
   // Try to materialize the conversion.
-  if (const TypeConverter *converter = rewrite->getConverter()) {
+  if (const TypeConverter *converter = info.getConverter()) {
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
-    switch (rewrite->getMaterializationKind()) {
+    switch (info.getMaterializationKind()) {
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
-          rewrite->getOriginalType());
+          info.getOriginalType());
       break;
     case MaterializationKind::Source:
       assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2767,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       &materializations = rewriterImpl.unresolvedMaterializations;
   for (auto it : materializations)
     allCastOps.push_back(it.first);
@@ -2785,7 +2784,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
       auto it = materializations.find(castOp);
       assert(it != materializations.end() && "inconsistent state");
-      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+      if (failed(
+              legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
         return failure();
     }
   }



More information about the Mlir-commits mailing list