[Mlir-commits] [mlir] [mlir][Transforms][NFC] Turn unresolved materializations into `IRRewrite`s (PR #81761)

Matthias Springer llvmlistbot at llvm.org
Fri Feb 23 01:04:28 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81761

>From 658d82839a7c85e8e5a6863eb980082c5c27b3a1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 19 Feb 2024 10:29:29 +0000
Subject: [PATCH] [WIP] UnresolvedMaterialization

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../Transforms/Utils/DialectConversion.cpp    | 369 +++++++++---------
 1 file changed, 176 insertions(+), 193 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 704597148dfac0..635a2cb00f3888 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
-                unsigned numIgnoredOperations, unsigned numErased)
-      : numUnresolvedMaterializations(numUnresolvedMaterializations),
-        numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
+      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of unresolved materializations.
-  unsigned numUnresolvedMaterializations;
-
   /// The current number of rewrites performed.
   unsigned numRewrites;
 
@@ -171,109 +167,10 @@ struct RewriterState {
   unsigned numErased;
 };
 
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
-  /// The type of materialization.
-  enum Kind {
-    /// This materialization materializes a conversion for an illegal block
-    /// argument type, to a legal one.
-    Argument,
-
-    /// This materialization materializes a conversion from an illegal type to a
-    /// legal one.
-    Target
-  };
-
-  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
-                            const TypeConverter *converter = nullptr,
-                            Kind kind = Target, Type origOutputType = nullptr)
-      : op(op), converterAndKind(converter, kind),
-        origOutputType(origOutputType) {}
-
-  /// Return the temporary conversion operation inserted for this
-  /// materialization.
-  UnrealizedConversionCastOp getOp() const { return op; }
-
-  /// Return the type converter of this materialization (which may be null).
-  const TypeConverter *getConverter() const {
-    return converterAndKind.getPointer();
-  }
-
-  /// Return the kind of this materialization.
-  Kind getKind() const { return converterAndKind.getInt(); }
-
-  /// Set the kind of this materialization.
-  void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
-private:
-  /// The unresolved materialization operation created during conversion.
-  UnrealizedConversionCastOp op;
-
-  /// The corresponding type converter to use when resolving this
-  /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
-    UnresolvedMaterialization::Kind kind, Block *insertBlock,
-    Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
-    Type origOutputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
-
-  // Create an unresolved materialization. We use a new OpBuilder to avoid
-  // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
-  auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
-                                          origOutputType);
-  return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
-      outputType, outputType, converter, unresolvedMaterializations);
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -299,7 +196,8 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
-    CreateOperation
+    CreateOperation,
+    UnresolvedMaterialization
   };
 
   virtual ~IRRewrite() = default;
@@ -605,7 +503,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::CreateOperation;
+           rewrite->getKind() <= Kind::UnresolvedMaterialization;
   }
 
 protected:
@@ -752,6 +650,70 @@ class CreateOperationRewrite : public OperationRewrite {
 
   void rollback() override;
 };
+
+/// The type of materialization.
+enum MaterializationKind {
+  /// This materialization materializes a conversion from an illegal block
+  /// argument type to a legal one.
+  Argument,
+
+  /// This materialization materializes a conversion from an illegal type to a
+  /// legal one.
+  Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl,
+      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+      MaterializationKind kind = MaterializationKind::Target,
+      Type origOutputType = nullptr)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+  void rollback() override;
+
+  void cleanup() override;
+
+  /// Return the type converter of this materialization (which may be null).
+  const TypeConverter *getConverter() const {
+    return converterAndKind.getPointer();
+  }
+
+  /// Return the kind of this materialization.
+  MaterializationKind getMaterializationKind() const {
+    return converterAndKind.getInt();
+  }
+
+  /// Set the kind of this materialization.
+  void setMaterializationKind(MaterializationKind kind) {
+    converterAndKind.setInt(kind);
+  }
+
+  /// Return the original illegal output type of the input values.
+  Type getOrigOutputType() const { return origOutputType; }
+
+private:
+  /// The corresponding type converter to use when resolving this
+  /// materialization, and the kind of this materialization.
+  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+      converterAndKind;
+
+  /// The original output type. This is only used for argument conversions.
+  Type origOutputType;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -794,14 +756,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
-  /// Cleanup and destroy any generated rewrite operations. This method is
-  /// invoked when the conversion process fails.
-  void discardRewrites();
-
-  /// Apply all requested operation rewrites. This method is invoked when the
-  /// conversion process succeeds.
-  void applyRewrites();
-
   //===--------------------------------------------------------------------===//
   // State Management
   //===--------------------------------------------------------------------===//
@@ -809,6 +763,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
 
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
@@ -841,17 +799,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
-  /// Detach any operations nested in the given operation from their parent
-  /// blocks, and erase the given operation. This can be used when the nested
-  /// operations are scheduled for erasure themselves, so deleting the regions
-  /// of the given operation together with their content would result in
-  /// double-free. This happens, for example, when rolling back op creation in
-  /// the reverse order and if the nested ops were created before the parent op.
-  /// This function does not need to collect nested ops recursively because it
-  /// is expected to also be called for each nested op when it is about to be
-  /// deleted.
-  void detachNestedAndErase(Operation *op);
-
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
@@ -890,6 +837,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
+  //===--------------------------------------------------------------------===//
+  // Materializations
+  //===--------------------------------------------------------------------===//
+  /// Build an unresolved materialization operation given an output type and set
+  /// of input operands.
+  Value buildUnresolvedMaterialization(MaterializationKind kind,
+                                       Block *insertBlock,
+                                       Block::iterator insertPt, Location loc,
+                                       ValueRange inputs, Type outputType,
+                                       Type origOutputType,
+                                       const TypeConverter *converter);
+
+  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+                                               Location loc, ValueRange inputs,
+                                               Type origOutputType,
+                                               Type outputType,
+                                               const TypeConverter *converter);
+
+  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+                                             Type outputType,
+                                             const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -969,10 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all unresolved type conversion materializations during
-  /// conversion.
-  SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
@@ -1162,24 +1127,15 @@ void CreateOperationRewrite::rollback() {
   eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region.getBlocks()) {
-      while (!block.getOperations().empty())
-        block.getOperations().remove(block.getOperations().begin());
-      block.dropAllDefinedValueUses();
-    }
+void UnresolvedMaterializationRewrite::rollback() {
+  if (getMaterializationKind() == MaterializationKind::Target) {
+    for (Value input : op->getOperands())
+      rewriterImpl.mapping.erase(input);
   }
-  eraseRewriter.eraseOp(op);
+  eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::discardRewrites() {
-  undoRewrites();
-
-  // Remove any newly created ops.
-  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
-    detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
@@ -1187,39 +1143,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit();
   for (auto &rewrite : rewrites)
     rewrite->cleanup();
-
-  // Drop all of the unresolved materialization operations created during
-  // conversion.
-  for (auto &mat : unresolvedMaterializations)
-    eraseRewriter.eraseOp(mat.getOp());
 }
 
 //===----------------------------------------------------------------------===//
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
-                       ignoredOps.size(), eraseRewriter.erased.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Undo any rewrites.
   undoRewrites(state.numRewrites);
 
-  // Pop all of the newly inserted materializations.
-  while (unresolvedMaterializations.size() !=
-         state.numUnresolvedMaterializations) {
-    UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
-    UnrealizedConversionCastOp op = mat.getOp();
-
-    // If this was a target materialization, drop the mapping that was inserted.
-    if (mat.getKind() == UnresolvedMaterialization::Target) {
-      for (Value input : op->getOperands())
-        mapping.erase(input);
-    }
-    detachNestedAndErase(op);
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1280,8 +1217,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
       Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter,
-          unresolvedMaterializations);
+          operandLoc, newOperand, desiredType, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1463,7 +1399,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
       newArg = buildUnresolvedArgumentMaterialization(
           rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
+          converter);
     }
 
     mapping.map(origArg, newArg);
@@ -1476,6 +1412,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   return newBlock;
 }
 
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    const TypeConverter *converter) {
+  // Avoid materializing an unnecessary cast.
+  if (inputs.size() == 1 && inputs.front().getType() == outputType)
+    return inputs.front();
+
+  // Create an unresolved materialization. We use a new OpBuilder to avoid
+  // tracking the materialization like we do for other operations.
+  OpBuilder builder(insertBlock, insertPt);
+  auto convertOp =
+      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  origOutputType);
+  return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+    PatternRewriter &rewriter, Location loc, ValueRange inputs,
+    Type origOutputType, Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(
+      MaterializationKind::Argument, rewriter.getInsertionBlock(),
+      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+      converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+    Location loc, Value input, Type outputType,
+    const TypeConverter *converter) {
+  Block *insertBlock = input.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
+    insertPt = ++inputRes.getOwner()->getIterator();
+
+  return buildUnresolvedMaterialization(MaterializationKind::Target,
+                                        insertBlock, insertPt, loc, input,
+                                        outputType, outputType, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2528,18 +2508,18 @@ LogicalResult OperationConverter::convertOperations(
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
-      return rewriterImpl.discardRewrites(), failure();
+      return rewriterImpl.undoRewrites(), failure();
 
   // Now that all of the operations have been converted, finalize the conversion
   // process to ensure any lingering conversion artifacts are cleaned up and
   // legalized.
   if (failed(finalize(rewriter)))
-    return rewriterImpl.discardRewrites(), failure();
+    return rewriterImpl.undoRewrites(), failure();
 
   // After a successful conversion, apply rewrites if this is not an analysis
   // conversion.
   if (mode == OpConversionMode::Analysis) {
-    rewriterImpl.discardRewrites();
+    rewriterImpl.undoRewrites();
   } else {
     rewriterImpl.applyRewrites();
   }
@@ -2645,11 +2625,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
 /// Compute all of the unresolved materializations that will persist beyond the
 /// conversion process, and require inserting a proper user materialization for.
 static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
     ConversionPatternRewriter &rewriter,
     ConversionPatternRewriterImpl &rewriterImpl,
     DenseMap<Value, SmallVector<Value>> &inverseMapping,
-    SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
+    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
   auto isLive = [&](Value value) {
     auto findFn = [&](Operation *user) {
       auto matIt = materializationOps.find(user);
@@ -2684,14 +2665,17 @@ static void computeNecessaryMaterializations(
         return Value();
       };
 
-  SetVector<UnresolvedMaterialization *> worklist;
-  for (auto &mat : rewriterImpl.unresolvedMaterializations) {
-    materializationOps.try_emplace(mat.getOp(), &mat);
-    worklist.insert(&mat);
+  SetVector<UnresolvedMaterializationRewrite *> worklist;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    materializationOps.try_emplace(mat->getOperation(), mat);
+    worklist.insert(mat);
   }
   while (!worklist.empty()) {
-    UnresolvedMaterialization *mat = worklist.pop_back_val();
-    UnrealizedConversionCastOp op = mat->getOp();
+    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
+    UnrealizedConversionCastOp op = mat->getOperation();
 
     // We currently only handle target materializations here.
     assert(op->getNumResults() == 1 && "unexpected materialization type");
@@ -2733,7 +2717,7 @@ static void computeNecessaryMaterializations(
     auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
     if (llvm::any_of(op->getOperands(), isBlockArg) ||
         llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
-      mat->setKind(UnresolvedMaterialization::Argument);
+      mat->setMaterializationKind(MaterializationKind::Argument);
     }
 
     // If the materialization does not have any live users, we don't need to
@@ -2743,7 +2727,7 @@ static void computeNecessaryMaterializations(
     // value replacement even if the types differ in some cases. When those
     // patterns are fixed, we can drop the argument special case here.
     bool isMaterializationLive = isLive(opResult);
-    if (mat->getKind() == UnresolvedMaterialization::Argument)
+    if (mat->getMaterializationKind() == MaterializationKind::Argument)
       isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
     if (!isMaterializationLive)
       continue;
@@ -2763,8 +2747,9 @@ static void computeNecessaryMaterializations(
 /// Legalize the given unresolved materialization. Returns success if the
 /// materialization was legalized, failure otherise.
 static LogicalResult legalizeUnresolvedMaterialization(
-    UnresolvedMaterialization &mat,
-    DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+    UnresolvedMaterializationRewrite &mat,
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
     ConversionPatternRewriter &rewriter,
     ConversionPatternRewriterImpl &rewriterImpl,
     DenseMap<Value, SmallVector<Value>> &inverseMapping) {
@@ -2784,7 +2769,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
         return Value();
       };
 
-  UnrealizedConversionCastOp op = mat.getOp();
+  UnrealizedConversionCastOp op = mat.getOperation();
   if (!rewriterImpl.ignoredOps.insert(op))
     return success();
 
@@ -2834,8 +2819,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
       rewriter.setInsertionPoint(op);
 
     Value newMaterialization;
-    switch (mat.getKind()) {
-    case UnresolvedMaterialization::Argument:
+    switch (mat.getMaterializationKind()) {
+    case MaterializationKind::Argument:
       // Try to materialize an argument conversion.
       // FIXME: The current argument materialization hook expects the original
       // output type, even though it doesn't use that as the actual output type
@@ -2852,7 +2837,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
       // If an argument materialization failed, fallback to trying a target
       // materialization.
       [[fallthrough]];
-    case UnresolvedMaterialization::Target:
+    case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
@@ -2880,14 +2865,12 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
     ConversionPatternRewriter &rewriter,
     ConversionPatternRewriterImpl &rewriterImpl,
     std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
-  if (rewriterImpl.unresolvedMaterializations.empty())
-    return success();
   inverseMapping = rewriterImpl.mapping.getInverse();
 
   // As an initial step, compute all of the inserted materializations that we
   // expect to persist beyond the conversion process.
-  DenseMap<Operation *, UnresolvedMaterialization *> materializationOps;
-  SetVector<UnresolvedMaterialization *> necessaryMaterializations;
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
+  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
   computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
                                    *inverseMapping, necessaryMaterializations);
 



More information about the Mlir-commits mailing list