[Mlir-commits] [mlir] [mlir][Transforms][WIP] Dialect conversion: Make materializations optional (PR #104668)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 17 02:40:50 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/104668

Build all source/target/argument materializations after the conversion has succeeded. A new configuration option that lets users opt out of all automatic materializations is planned but not yet part of this draft; with that option enabled, the resulting IR will contain `builtin.unrealized_conversion_cast` ops.

Draft only, do not merge.


>From 23c0a47fd8e98689d66776c7b102ff0d2ccd91c5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 Aug 2024 11:38:40 +0200
Subject: [PATCH] [mlir][Transforms][WIP] Dialect conversion: Make
 materializations optional

Build all source/target/argument materializations after the conversion has succeeded. A new configuration option that lets users opt out of all automatic materializations is planned but not yet part of this draft; with that option enabled, the resulting IR will contain `builtin.unrealized_conversion_cast` ops.
---
 .../Transforms/Utils/DialectConversion.cpp    | 474 +++++++-----------
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |   5 +-
 .../Transforms/finalizing-bufferize.mlir      |   1 +
 .../test-legalize-type-conversion.mlir        |   6 +-
 4 files changed, 179 insertions(+), 307 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..56eec3e8b00c16 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -686,6 +686,16 @@ enum MaterializationKind {
   Source
 };
 
+struct UnresolvedMaterializationMetadata {
+  UnresolvedMaterializationMetadata() = default;
+  UnresolvedMaterializationMetadata(const TypeConverter *converter,
+                                    MaterializationKind kind)
+      : converter(converter), kind(kind) {}
+  /// Converter to resolve with (may be null) and the materialization kind.
+  const TypeConverter *converter = nullptr;
+  MaterializationKind kind = MaterializationKind::Target;
+};
+
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
 /// op. Unresolved materializations are erased at the end of the dialect
 /// conversion.
@@ -696,7 +706,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
       MaterializationKind kind = MaterializationKind::Target)
       : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind) {}
+        metadata(converter, kind) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,23 +718,12 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   void rollback() override;
 
-  void cleanup(RewriterBase &rewriter) override;
-
-  /// Return the type converter of this materialization (which may be null).
-  const TypeConverter *getConverter() const {
-    return converterAndKind.getPointer();
-  }
-
-  /// Return the kind of this materialization.
-  MaterializationKind getMaterializationKind() const {
-    return converterAndKind.getInt();
-  }
+  UnresolvedMaterializationMetadata getMetadata() const { return metadata; }
 
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
-      converterAndKind;
+  UnresolvedMaterializationMetadata metadata;
 };
 } // namespace
 
@@ -834,6 +833,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
+
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -841,6 +841,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
+  SmallVector<UnrealizedConversionCastOp> getUnresolvedMaterializations();
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -877,8 +879,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(MLIRContext *ctx)
+        : RewriterBase(ctx, /*listener=*/this) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -930,6 +932,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
+  SmallVector<Operation *> unresolvedMaterializations;
+  DenseSet<Operation *> erasedUnresolvedMaterializations;
+  DenseMap<Operation *, UnresolvedMaterializationMetadata>
+      unresolvedMaterializationMetadata;
+
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
@@ -1051,17 +1058,17 @@ void CreateOperationRewrite::rollback() {
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (getMaterializationKind() == MaterializationKind::Target) {
+  auto it = llvm::find(rewriterImpl.unresolvedMaterializations, op);
+  assert(it != rewriterImpl.unresolvedMaterializations.end() &&
+         "inconsistent state");
+  rewriterImpl.unresolvedMaterializations.erase(it);
+  if (metadata.kind == MaterializationKind::Target) {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
-  rewriter.eraseOp(op);
-}
-
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1072,6 +1079,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
   SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
+  // Record every erased op; getUnresolvedMaterializations() skips these.
+  for (void *op : eraseRewriter.erased)
+    erasedUnresolvedMaterializations.insert(static_cast<Operation *>(op));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1342,10 +1352,22 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  unresolvedMaterializations.push_back(convertOp);
+  unresolvedMaterializationMetadata[convertOp] =
+      UnresolvedMaterializationMetadata(converter, kind);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
 
+SmallVector<UnrealizedConversionCastOp>
+ConversionPatternRewriterImpl::getUnresolvedMaterializations() {
+  SmallVector<UnrealizedConversionCastOp> result;
+  for (Operation *op : unresolvedMaterializations)
+    if (!erasedUnresolvedMaterializations.contains(op))
+      result.push_back(cast<UnrealizedConversionCastOp>(op));
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2354,12 +2376,6 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// Legalize any unresolved type materializations.
-  LogicalResult legalizeUnresolvedMaterializations(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2406,6 +2422,138 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+/// Simplifies the given unrealized_conversion_cast ops and returns the ones
+/// that were not erased. A cast without users is erased; a cast whose result
+/// types match the input types of itself or of a cast further up its chain of
+/// input casts is replaced with those inputs and erased. Casts that define the
+/// operands of an erased cast are revisited, even if they are not part of
+/// `castOps`. Ops are erased directly, without notifying any listener.
+static SmallVector<UnrealizedConversionCastOp>
+reconcileUnrealizedCasts(ArrayRef<UnrealizedConversionCastOp> castOps) {
+  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+                                                 castOps.end());
+  DenseSet<Operation *> erasedOps;
+
+  // Helper function that adds all operands to the worklist that are an
+  // unrealized_conversion_cast op result.
+  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+    for (Value v : castOp.getInputs())
+      if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+        worklist.insert(inputCastOp);
+  };
+
+  // Helper function that returns the unrealized_conversion_cast op that
+  // defines all inputs of the given op (in the same order). Returns a null op
+  // if there is no such op.
+  auto getInputCast =
+      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+    if (castOp.getInputs().empty())
+      return {};
+    auto inputCastOp =
+        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!inputCastOp)
+      return {};
+    if (inputCastOp.getOutputs() != castOp.getInputs())
+      return {};
+    return inputCastOp;
+  };
+
+  // Process ops in the worklist bottom-to-top.
+  while (!worklist.empty()) {
+    UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+    if (castOp->use_empty()) {
+      // DCE: If the op has no users, erase it. Add the operands to the
+      // worklist to find additional DCE opportunities.
+      enqueueOperands(castOp);
+      erasedOps.insert(castOp.getOperation());
+      castOp->erase();
+      continue;
+    }
+
+    // Traverse the chain of input cast ops to see if an op with the same
+    // input types can be found.
+    UnrealizedConversionCastOp nextCast = castOp;
+    while (nextCast) {
+      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+        // Found a cast where the input types match the output types of the
+        // matched op. We can directly use those inputs and the matched op can
+        // be removed.
+        enqueueOperands(castOp);
+        castOp.replaceAllUsesWith(nextCast.getInputs());
+        erasedOps.insert(castOp.getOperation());
+        castOp->erase();
+        break;
+      }
+      nextCast = getInputCast(nextCast);
+    }
+  }
+
+  SmallVector<UnrealizedConversionCastOp> result;
+  for (UnrealizedConversionCastOp op : castOps)
+    if (!erasedOps.contains(op.getOperation()))
+      result.push_back(op);
+
+  return result;
+}
+
+/// Replaces the unresolved materialization `op` with a materialization built
+/// by the type converter in `metadata`, according to the materialization kind
+/// (an argument materialization falls back to a target materialization). If
+/// no materialization can be built, emits an error and returns failure.
+static LogicalResult
+legalizeUnresolvedMaterialization(RewriterBase &rewriter,
+                                  UnrealizedConversionCastOp op,
+                                  UnresolvedMaterializationMetadata metadata) {
+  assert(!op.use_empty() &&
+         "expected that dead materializations have already been DCE'd");
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = op.getResultTypes()[0];
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = metadata.converter) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (metadata.kind) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fall back to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      rewriter.replaceOp(op, newMaterialization);
+      return success();
+    }
+  }
+
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
+                            << " that remained live after conversion";
+  // Do not rely on the assertion above; it is compiled out in release builds.
+  if (!op->use_empty()) {
+    Operation *liveUser = *op->getUsers().begin();
+    diag.attachNote(liveUser->getLoc())
+        << "see existing live user here: " << *liveUser;
+  }
+  return failure();
+}
+
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2447,6 +2595,21 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
+
+  // Fold away unresolved materializations that are no longer needed, then
+  // replace the remaining ones with materializations built by the type
+  // converter. NOTE(review): after applyRewrites() above, the rewrites are
+  // committed, so a failure here presumably cannot be rolled back - confirm.
+  IRRewriter materializationRewriter(rewriterImpl.context,
+                                     rewriterImpl.config.listener);
+  SmallVector<UnrealizedConversionCastOp> castOps =
+      reconcileUnrealizedCasts(rewriterImpl.getUnresolvedMaterializations());
+  for (UnrealizedConversionCastOp op : castOps)
+    if (failed(legalizeUnresolvedMaterialization(
+            materializationRewriter, op,
+            rewriterImpl.unresolvedMaterializationMetadata.lookup(op))))
+      return failure();
+
   return success();
 }
 
@@ -2460,9 +2604,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)))
-    return failure();
   return success();
 }
 
@@ -2577,279 +2718,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
-                       ResultRange matResults, ValueRange values,
-                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  matResults.replaceAllUsesWith(values);
-
-  // For each of the materialization results, update the inverse mappings to
-  // point to the replacement values.
-  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
-    auto inverseMapIt = inverseMapping.find(matResult);
-    if (inverseMapIt == inverseMapping.end())
-      continue;
-
-    // Update the reverse mapping, or remove the mapping if we couldn't update
-    // it. Not being able to update signals that the mapping would have become
-    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
-    // propagated through temporary materializations. We simply drop the
-    // mapping, and let the post-conversion replacement logic handle updating
-    // uses.
-    for (Value inverseMapVal : inverseMapIt->second)
-      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
-        rewriterImpl.mapping.erase(inverseMapVal);
-  }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping,
-    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
-  // Helper function to check if the given value or a not yet materialized
-  // replacement of the given value is live.
-  // Note: `inverseMapping` maps from replaced values to original values.
-  auto isLive = [&](Value value) {
-    auto findFn = [&](Operation *user) {
-      auto matIt = materializationOps.find(user);
-      if (matIt != materializationOps.end())
-        return !necessaryMaterializations.count(matIt->second);
-      return rewriterImpl.isOpIgnored(user);
-    };
-    // A worklist is needed because a value may have gone through a chain of
-    // replacements and each of the replaced values may have live users.
-    SmallVector<Value> worklist;
-    worklist.push_back(value);
-    while (!worklist.empty()) {
-      Value next = worklist.pop_back_val();
-      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
-        return true;
-      // This value may be replacing another value that has a live user.
-      llvm::append_range(worklist, inverseMapping.lookup(next));
-    }
-    return false;
-  };
-
-  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
-      [&](Value invalidRoot, Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type && remappedValue != invalidRoot)
-          return remappedValue;
-
-        // Check to see if the input is a materialization operation that
-        // provides an inverse conversion. We just check blindly for
-        // UnrealizedConversionCastOp here, but it has no effect on correctness.
-        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-        if (inputCastOp && inputCastOp->getNumOperands() == 1)
-          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
-                                     type);
-
-        return Value();
-      };
-
-  SetVector<UnresolvedMaterializationRewrite *> worklist;
-  for (auto &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    materializationOps.try_emplace(mat->getOperation(), mat);
-    worklist.insert(mat);
-  }
-  while (!worklist.empty()) {
-    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
-    UnrealizedConversionCastOp op = mat->getOperation();
-
-    // We currently only handle target materializations here.
-    assert(op->getNumResults() == 1 && "unexpected materialization type");
-    OpResult opResult = op->getOpResult(0);
-    Type outputType = opResult.getType();
-    Operation::operand_range inputOperands = op.getOperands();
-
-    // Try to forward propagate operands for user conversion casts that result
-    // in the input types of the current cast.
-    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
-      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-      if (!castOp)
-        continue;
-      if (castOp->getResultTypes() == inputOperands.getTypes()) {
-        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
-                               inverseMapping);
-        necessaryMaterializations.remove(materializationOps.lookup(user));
-      }
-    }
-
-    // Try to avoid materializing a resolved materialization if possible.
-    // Handle the case of a 1-1 materialization.
-    if (inputOperands.size() == 1) {
-      // Check to see if the input operation was remapped to a variant of the
-      // output.
-      Value remappedValue =
-          lookupRemappedValue(opResult, inputOperands[0], outputType);
-      if (remappedValue && remappedValue != opResult) {
-        replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                               inverseMapping);
-        necessaryMaterializations.remove(mat);
-        continue;
-      }
-    } else {
-      // TODO: Avoid materializing other types of conversions here.
-    }
-
-    // If the materialization does not have any live users, we don't need to
-    // generate a user materialization for it.
-    bool isMaterializationLive = isLive(opResult);
-    if (!isMaterializationLive)
-      continue;
-    if (!necessaryMaterializations.insert(mat))
-      continue;
-
-    // Reprocess input materializations to see if they have an updated status.
-    for (Value input : inputOperands) {
-      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
-        if (auto *mat = materializationOps.lookup(parentOp))
-          worklist.insert(mat);
-      }
-    }
-  }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
-    UnresolvedMaterializationRewrite &mat,
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  auto findLiveUser = [&](auto &&users) {
-    auto liveUserIt = llvm::find_if_not(
-        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
-    return liveUserIt == users.end() ? nullptr : *liveUserIt;
-  };
-
-  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
-      [&](Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type)
-          return remappedValue;
-        return Value();
-      };
-
-  UnrealizedConversionCastOp op = mat.getOperation();
-  if (!rewriterImpl.ignoredOps.insert(op))
-    return success();
-
-  // We currently only handle target materializations here.
-  OpResult opResult = op->getOpResult(0);
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = opResult.getType();
-
-  // If any input to this materialization is another materialization, resolve
-  // the input first.
-  for (Value value : op->getOperands()) {
-    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
-    if (!valueCast)
-      continue;
-
-    auto matIt = materializationOps.find(valueCast);
-    if (matIt != materializationOps.end())
-      if (failed(legalizeUnresolvedMaterialization(
-              *matIt->second, materializationOps, rewriter, rewriterImpl,
-              inverseMapping)))
-        return failure();
-  }
-
-  // Perform a last ditch attempt to avoid materializing a resolved
-  // materialization if possible.
-  // Handle the case of a 1-1 materialization.
-  if (inputOperands.size() == 1) {
-    // Check to see if the input operation was remapped to a variant of the
-    // output.
-    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
-    if (remappedValue && remappedValue != opResult) {
-      replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                             inverseMapping);
-      return success();
-    }
-  } else {
-    // TODO: Avoid materializing other types of conversions here.
-  }
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = mat.getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (mat.getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      replaceMaterialization(rewriterImpl, opResult, newMaterialization,
-                             inverseMapping);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  if (Operation *liveUser = findLiveUser(op->getUsers())) {
-    diag.attachNote(liveUser->getLoc())
-        << "see existing live user here: " << *liveUser;
-  }
-  return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // As an initial step, compute all of the inserted materializations that we
-  // expect to persist beyond the conversion process.
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
-  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
-  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
-                                   inverseMapping, necessaryMaterializations);
-
-  // Once computed, legalize any necessary materializations.
-  for (auto *mat : necessaryMaterializations) {
-    if (failed(legalizeUnresolvedMaterialization(
-            *mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
-      return failure();
-  }
-  return success();
-}
-
 LogicalResult OperationConverter::legalizeErasedResult(
     Operation *op, OpResult result,
     ConversionPatternRewriterImpl &rewriterImpl) {
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 156a8a468d5b42..75362378daaaaa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1286,7 +1286,6 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 
 // CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
 // CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
 // CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1299,8 +1298,8 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 // CHECK: %[[S136:.+]] = llvm.insertvalue %[[S134]], %[[S135]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S138:.+]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index a192434c5accf8..ab18ce05e355d3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -80,6 +80,7 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
   %0 = bufferization.to_tensor %m : memref<?xf32>
   // expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
   %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  // expected-note @below{{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
 
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index cf2c9f6a8ec441..f130adff42f8cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -4,6 +4,7 @@
 func.func @test_invalid_arg_materialization(
   // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -22,6 +23,7 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
 func.func @test_invalid_result_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -30,6 +32,7 @@ func.func @test_invalid_result_materialization() {
 func.func @test_invalid_result_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -49,6 +52,7 @@ func.func @test_transitive_use_materialization() {
 func.func @test_transitive_use_invalid_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.another_type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -99,9 +103,9 @@ func.func @test_block_argument_not_converted() {
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
-  // expected-note@below {{see existing live user here}}
   ^bb0(%arg0: f32):
     "test.type_consumer"(%arg0) : (f32) -> ()
+    // expected-note@below{{see existing live user here}}
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return



More information about the Mlir-commits mailing list