[Mlir-commits] [mlir] [mlir][Transforms] Add option to build materializations immediately (PR #156030)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 29 07:18:56 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/156030

None

>From 6477f75393aedd50463c89f883407a2060c7f490 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 29 Aug 2025 14:13:52 +0000
Subject: [PATCH] [mlir][Transforms] Add option to build materializations
 immediately

---
 .../mlir/Transforms/DialectConversion.h       |  28 ++--
 .../Transforms/Utils/DialectConversion.cpp    | 122 ++++++++++++------
 mlir/test/Transforms/test-legalizer.mlir      |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  17 ++-
 4 files changed, 113 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..e13937213d2f4 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1346,15 +1346,25 @@ struct ConversionConfig {
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
 
-  /// If set to "true", the dialect conversion attempts to build source/target
-  /// materializations through the type converter API in lieu of
-  /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
-  /// at least one materialization could not be built.
-  ///
-  /// If set to "false", the dialect conversion does not build any custom
-  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
-  /// ops to ensure that the resulting IR is valid.
-  bool buildMaterializations = true;
+  enum class MaterializationMode {
+    /// Never build materializations with the type converter. Instead, insert
+    /// "builtin.unrealized_conversion_cast" ops to ensure that the types of
+    /// the resulting IR are valid.
+    Never,
+    /// Build materializations with the type converter immediately. If that
+    /// fails, insert a "builtin.unrealized_conversion_cast" op (retried later).
+    Immediate,
+    /// Insert "builtin.unrealized_conversion_cast" ops first. At the end of
+    /// the conversion, replace them with materializations built with the type
+    /// converter. This can result in fewer materializations because
+    /// "builtin.unrealized_conversion_cast" ops that cancel each other out are
+    /// folded away.
+    Delayed
+  };
+
+  /// Controls whether and when materializations are built through the type
+  /// converter API. (Formerly a bool: true = Delayed, false = Never.)
+  MaterializationMode buildMaterializations = MaterializationMode::Delayed;
 
   /// If set to "true", pattern rollback is allowed. The conversion driver
   /// rolls back IR modifications in the following situations.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c0685f54731d5..18adee7b19175 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -810,7 +810,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   void rollback() override;
 
   UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
+    return dyn_cast_or_null<UnrealizedConversionCastOp>(op);
   }
 
 private:
@@ -973,10 +973,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// "out of thin air" appear like unresolved materializations because they are
   /// unrealized_conversion_cast ops. However, they must be treated like
   /// regular value replacements.)
-  ValueRange buildUnresolvedMaterialization(
-      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
-      Type originalType, const TypeConverter *converter,
+  ValueVector buildUnresolvedMaterialization(
+      OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip,
+      Location loc, ValueVector valuesToMap, ValueRange inputs,
+      TypeRange outputTypes, Type originalType, const TypeConverter *converter,
       bool isPureTypeConversion = true);
 
   /// Find a replacement value for the given SSA value in the conversion value
@@ -984,7 +984,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// value. If there is no replacement value with the correct type, find the
   /// latest replacement value (regardless of the type) and build a source
   /// materialization.
-  Value findOrBuildReplacementValue(Value value,
+  Value findOrBuildReplacementValue(OpBuilder &builder, Value value,
                                     const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
@@ -1188,7 +1188,8 @@ static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
 }
 
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+  Value repl =
+      rewriterImpl.findOrBuildReplacementValue(rewriter, arg, converter);
   if (!repl)
     return;
   performReplaceBlockArg(rewriter, arg, repl);
@@ -1203,7 +1204,8 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   // Compute replacement values.
   SmallVector<Value> replacements =
       llvm::map_to_vector(op->getResults(), [&](OpResult result) {
-        return rewriterImpl.findOrBuildReplacementValue(result, converter);
+        return rewriterImpl.findOrBuildReplacementValue(rewriter, result,
+                                                        converter);
       });
 
   // Notify the listener that the operation is about to be replaced.
@@ -1251,8 +1253,10 @@ void CreateOperationRewrite::rollback() {
 void UnresolvedMaterializationRewrite::rollback() {
   if (!mappedValues.empty())
     rewriterImpl.mapping.erase(mappedValues);
-  rewriterImpl.unresolvedMaterializations.erase(getOperation());
-  op->erase();
+  if (getOperation()) {
+    rewriterImpl.unresolvedMaterializations.erase(getOperation());
+    op->erase();
+  }
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1458,11 +1462,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     // Create a materialization for the most recently mapped values.
     repl = lookupOrDefault(operand, /*desiredTypes=*/{},
                            /*skipPureTypeConversions=*/true);
-    ValueRange castValues = buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+    ValueVector castValues = buildUnresolvedMaterialization(
+        rewriter, MaterializationKind::Target, computeInsertPoint(repl),
+        operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
         /*originalType=*/origType, currentTypeConverter);
-    remapped.push_back(castValues);
+    remapped.push_back(std::move(castValues));
   }
   return success();
 }
@@ -1577,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       // Materialize a replacement value "out of thin air".
       Value mat =
           buildUnresolvedMaterialization(
-              MaterializationKind::Source,
+              rewriter, MaterializationKind::Source,
               OpBuilder::InsertPoint(newBlock, newBlock->begin()),
               origArg.getLoc(),
               /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
@@ -1620,46 +1625,76 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
-ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
-    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
-    Type originalType, const TypeConverter *converter,
+ValueVector ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+    OpBuilder &builder, MaterializationKind kind, OpBuilder::InsertPoint ip,
+    Location loc, ValueVector valuesToMap, ValueRange inputs,
+    TypeRange outputTypes, Type originalType, const TypeConverter *converter,
     bool isPureTypeConversion) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   assert(TypeRange(inputs) != outputTypes &&
          "materialization is not necessary");
+  ValueVector results;
+
+  // Build materializations with the type converter if requested.
+  if (converter && config.buildMaterializations ==
+                       ConversionConfig::MaterializationMode::Immediate) {
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
+    if (kind == MaterializationKind::Source) {
+      assert(outputTypes.size() == 1 && "expected single output type");
+      Value mat = converter->materializeSourceConversion(
+          builder, loc, outputTypes.front(), inputs);
+      if (mat)
+        results.push_back(mat);
+    } else {
+      assert(kind == MaterializationKind::Target &&
+             "expected source or target materialization");
+      SmallVector<Value> mat = converter->materializeTargetConversion(
+          builder, loc, outputTypes, inputs, originalType);
+      if (!mat.empty())
+        llvm::append_range(results, mat);
+    }
+  }
 
-  // Create an unresolved materialization. We use a new OpBuilder to avoid
-  // tracking the materialization like we do for other operations.
-  OpBuilder builder(outputTypes.front().getContext());
-  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
-  UnrealizedConversionCastOp convertOp =
-      UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
-  if (config.attachDebugMaterializationKind) {
-    StringRef kindStr =
-        kind == MaterializationKind::Source ? "source" : "target";
-    convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+  // Otherwise, or if the type converter failed to build a materialization,
+  // insert an unrealized_conversion_cast op and try to resolve it later.
+  UnrealizedConversionCastOp castOp;
+  if (results.empty()) {
+    // Create an unresolved materialization. We use a new OpBuilder to avoid
+    // tracking the materialization like we do for other operations.
+    OpBuilder castBuilder(outputTypes.front().getContext());
+    castBuilder.setInsertionPoint(ip.getBlock(), ip.getPoint());
+    castOp = UnrealizedConversionCastOp::create(castBuilder, loc, outputTypes,
+                                                inputs);
+    // Track the cast op as an unresolved materialization.
+    unresolvedMaterializations[castOp] =
+        UnresolvedMaterializationInfo(converter, kind, originalType);
+    if (config.attachDebugMaterializationKind) {
+      StringRef kindStr =
+          kind == MaterializationKind::Source ? "source" : "target";
+      castOp->setAttr("__kind__", castBuilder.getStringAttr(kindStr));
+    }
+    if (isPureTypeConversion)
+      castOp->setAttr(kPureTypeConversionMarker, castBuilder.getUnitAttr());
+    llvm::append_range(results, castOp.getResults());
   }
-  if (isPureTypeConversion)
-    convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
 
   // Register the materialization.
-  unresolvedMaterializations[convertOp] =
-      UnresolvedMaterializationInfo(converter, kind, originalType);
   if (config.allowPatternRollback) {
     if (!valuesToMap.empty())
-      mapping.map(valuesToMap, convertOp.getResults());
-    appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+      mapping.map(valuesToMap, results);
+    appendRewrite<UnresolvedMaterializationRewrite>(castOp,
                                                     std::move(valuesToMap));
   } else {
-    patternMaterializations.insert(convertOp);
+    if (castOp)
+      patternMaterializations.insert(castOp);
   }
-  return convertOp.getResults();
+  return results;
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
-    Value value, const TypeConverter *converter) {
+    OpBuilder &builder, Value value, const TypeConverter *converter) {
   assert(config.allowPatternRollback &&
          "this code path is valid only in rollback mode");
 
@@ -1700,7 +1735,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // materialization must be valid for all future users that may be created
   // later in the conversion process.
   Value castValue =
-      buildUnresolvedMaterialization(MaterializationKind::Source,
+      buildUnresolvedMaterialization(builder, MaterializationKind::Source,
                                      computeInsertPoint(repl), value.getLoc(),
                                      /*valuesToMap=*/repl, /*inputs=*/repl,
                                      /*outputTypes=*/value.getType(),
@@ -1779,7 +1814,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
       // The replaced value is dropped. Materialize a replacement value "out of
       // thin air".
       Value srcMat = impl.buildUnresolvedMaterialization(
-          MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+          impl.rewriter, MaterializationKind::Source, computeInsertPoint(from),
+          from.getLoc(),
           /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
           /*outputTypes=*/from.getType(), /*originalType=*/Type(),
           converter)[0];
@@ -1799,7 +1835,8 @@ getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
     // materializations if possible. This would require an extension of the
     // `lookupOrDefault` API.
     Value srcMat = impl.buildUnresolvedMaterialization(
-        MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+        impl.rewriter, MaterializationKind::Source, computeInsertPoint(to),
+        from.getLoc(),
         /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
         /*originalType=*/Type(), converter)[0];
     repls.push_back(srcMat);
@@ -1855,7 +1892,7 @@ void ConversionPatternRewriterImpl::replaceOp(
       // This result was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
       buildUnresolvedMaterialization(
-          MaterializationKind::Source, computeInsertPoint(result),
+          rewriter, MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputTypes=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter, /*isPureTypeConversion=*/false);
@@ -3234,7 +3271,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     castOp->removeAttr(kPureTypeConversionMarker);
 
   // Try to legalize all unresolved materializations.
-  if (rewriter.getConfig().buildMaterializations) {
+  if (rewriter.getConfig().buildMaterializations !=
+      ConversionConfig::MaterializationMode::Never) {
     // Use a new rewriter, so the modifications are not tracked for rollback
     // purposes etc.
     IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..5ec12be83ebca 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=never attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
 
 // CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
 // CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..d6126de5dec7e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1660,11 +1660,20 @@ struct TestLegalizePatternDriver
       llvm::cl::desc(
           "Attach materialization kind to unrealized_conversion_cast ops"),
       llvm::cl::init(false)};
-  Option<bool> buildMaterializations{
+  Option<ConversionConfig::MaterializationMode> buildMaterializations{
       *this, "build-materializations",
-      llvm::cl::desc(
-          "If set to 'false', leave unrealized_conversion_cast ops in place"),
-      llvm::cl::init(true)};
+      llvm::cl::desc("When to build unresolved materializations."),
+      llvm::cl::init(ConversionConfig::MaterializationMode::Delayed),
+      llvm::cl::values(
+          clEnumValN(ConversionConfig::MaterializationMode::Never, "never",
+                     "Never build materializations with the type converter."),
+          clEnumValN(ConversionConfig::MaterializationMode::Delayed, "delayed",
+                     "Build materializations with the type converter at the "
+                     "end of the conversion."),
+          clEnumValN(ConversionConfig::MaterializationMode::Immediate,
+                     "immediate",
+                     "Build materializations with the type converter "
+                     "immediately."))};
 };
 } // namespace
 



More information about the Mlir-commits mailing list