[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Make materializations optional (PR #104668)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 23 08:48:26 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/104668

>From 41949c94c651158abefa174a19724a0994f73a00 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 Aug 2024 11:38:40 +0200
Subject: [PATCH] [mlir][Transforms][WIP] Dialect conversion: Make
 materializations optional

Build all source/target/argument materializations after the conversion has succeeded. Provide a new configuration option for users to opt out of all automatic materializations. In that case, the resulting IR will have `builtin.unrealized_conversion_cast` ops.
---
 .../mlir/Transforms/DialectConversion.h       |  11 +
 .../Transforms/Utils/DialectConversion.cpp    | 383 ++++--------------
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |   5 +-
 .../Transforms/finalizing-bufferize.mlir      |   1 +
 .../test-legalize-type-conversion.mlir        |   6 +-
 5 files changed, 108 insertions(+), 298 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 60113bdef16a23..5f680e8eca7559 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,6 +1124,17 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
+
+  /// If set to "true", the dialect conversion attempts to build source/target/
+  /// argument materializations through the type converter API in lieu of
+  /// builtin.unrealized_conversion_cast ops. The conversion process fails if
+  /// at least one materialization could not be built.
+  ///
+  /// If set to "false", the dialect conversion does not build any
+  /// custom materializations and instead inserts
+  /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
+  /// is valid.
+  bool buildMaterializations = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index adf012a261cb7e..c6d7c4691b1ccd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,14 +702,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
   }
 
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
-
   void rollback() override;
 
-  void cleanup(RewriterBase &rewriter) override;
-
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
@@ -766,7 +760,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), config(config) {}
+      : context(ctx), eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -834,6 +828,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
+
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -912,6 +907,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
+  /// A rewriter that keeps track of ops/blocks that were already erased and
+  /// skips duplicate op/block erasures. This rewriter is used during the
+  /// "cleanup" phase.
+  SingleEraseRewriter eraseRewriter;
+
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1058,10 +1058,6 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
-  rewriter.eraseOp(op);
-}
-
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1069,7 +1065,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit(rewriter);
 
   // Clean up all rewrites.
-  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2354,12 +2349,6 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// Legalize any unresolved type materializations.
-  LogicalResult legalizeUnresolvedMaterializations(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2406,6 +2395,57 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+static LogicalResult
+legalizeUnresolvedMaterialization(RewriterBase &rewriter,
+                                  UnresolvedMaterializationRewrite *rewrite) {
+  UnrealizedConversionCastOp op =
+      cast<UnrealizedConversionCastOp>(rewrite->getOperation());
+  assert(!op.use_empty() &&
+         "expected that dead materializations have already been DCE'd");
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = op.getResultTypes()[0];
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = rewrite->getConverter()) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (rewrite->getMaterializationKind()) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      rewriter.replaceOp(op, newMaterialization);
+      return success();
+    }
+  }
+
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
+                            << " that remained live after conversion";
+  diag.attachNote(op->getUsers().begin()->getLoc())
+      << "see existing live user here: " << *op->getUsers().begin();
+  return failure();
+}
+
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2447,6 +2487,37 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
+
+  // Gather all unresolved materializations.
+  SmallVector<UnrealizedConversionCastOp> allCastOps;
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    if (rewriterImpl.eraseRewriter.erased.contains(mat->getOperation()))
+      continue;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(mat->getOperation()));
+    rewriteMap[mat->getOperation()] = mat;
+  }
+
+  // Reconcile all UnrealizedConversionCastOps that were inserted by the
+  // dialect conversion framework. (Not the ones that were inserted by
+  // patterns.)
+  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
+  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+
+  // Try to legalize all unresolved materializations.
+  if (config.buildMaterializations) {
+    IRRewriter rewriter(rewriterImpl.context, config.listener);
+    for (UnrealizedConversionCastOp castOp : remainingCastOps) {
+      auto it = rewriteMap.find(castOp.getOperation());
+      assert(it != rewriteMap.end() && "inconsistent state");
+      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+        return failure();
+    }
+  }
+
   return success();
 }
 
@@ -2460,9 +2531,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)))
-    return failure();
   return success();
 }
 
@@ -2578,279 +2646,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
-                       ResultRange matResults, ValueRange values,
-                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  matResults.replaceAllUsesWith(values);
-
-  // For each of the materialization results, update the inverse mappings to
-  // point to the replacement values.
-  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
-    auto inverseMapIt = inverseMapping.find(matResult);
-    if (inverseMapIt == inverseMapping.end())
-      continue;
-
-    // Update the reverse mapping, or remove the mapping if we couldn't update
-    // it. Not being able to update signals that the mapping would have become
-    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
-    // propagated through temporary materializations. We simply drop the
-    // mapping, and let the post-conversion replacement logic handle updating
-    // uses.
-    for (Value inverseMapVal : inverseMapIt->second)
-      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
-        rewriterImpl.mapping.erase(inverseMapVal);
-  }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping,
-    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
-  // Helper function to check if the given value or a not yet materialized
-  // replacement of the given value is live.
-  // Note: `inverseMapping` maps from replaced values to original values.
-  auto isLive = [&](Value value) {
-    auto findFn = [&](Operation *user) {
-      auto matIt = materializationOps.find(user);
-      if (matIt != materializationOps.end())
-        return !necessaryMaterializations.count(matIt->second);
-      return rewriterImpl.isOpIgnored(user);
-    };
-    // A worklist is needed because a value may have gone through a chain of
-    // replacements and each of the replaced values may have live users.
-    SmallVector<Value> worklist;
-    worklist.push_back(value);
-    while (!worklist.empty()) {
-      Value next = worklist.pop_back_val();
-      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
-        return true;
-      // This value may be replacing another value that has a live user.
-      llvm::append_range(worklist, inverseMapping.lookup(next));
-    }
-    return false;
-  };
-
-  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
-      [&](Value invalidRoot, Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type && remappedValue != invalidRoot)
-          return remappedValue;
-
-        // Check to see if the input is a materialization operation that
-        // provides an inverse conversion. We just check blindly for
-        // UnrealizedConversionCastOp here, but it has no effect on correctness.
-        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-        if (inputCastOp && inputCastOp->getNumOperands() == 1)
-          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
-                                     type);
-
-        return Value();
-      };
-
-  SetVector<UnresolvedMaterializationRewrite *> worklist;
-  for (auto &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    materializationOps.try_emplace(mat->getOperation(), mat);
-    worklist.insert(mat);
-  }
-  while (!worklist.empty()) {
-    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
-    UnrealizedConversionCastOp op = mat->getOperation();
-
-    // We currently only handle target materializations here.
-    assert(op->getNumResults() == 1 && "unexpected materialization type");
-    OpResult opResult = op->getOpResult(0);
-    Type outputType = opResult.getType();
-    Operation::operand_range inputOperands = op.getOperands();
-
-    // Try to forward propagate operands for user conversion casts that result
-    // in the input types of the current cast.
-    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
-      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-      if (!castOp)
-        continue;
-      if (castOp->getResultTypes() == inputOperands.getTypes()) {
-        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
-                               inverseMapping);
-        necessaryMaterializations.remove(materializationOps.lookup(user));
-      }
-    }
-
-    // Try to avoid materializing a resolved materialization if possible.
-    // Handle the case of a 1-1 materialization.
-    if (inputOperands.size() == 1) {
-      // Check to see if the input operation was remapped to a variant of the
-      // output.
-      Value remappedValue =
-          lookupRemappedValue(opResult, inputOperands[0], outputType);
-      if (remappedValue && remappedValue != opResult) {
-        replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                               inverseMapping);
-        necessaryMaterializations.remove(mat);
-        continue;
-      }
-    } else {
-      // TODO: Avoid materializing other types of conversions here.
-    }
-
-    // If the materialization does not have any live users, we don't need to
-    // generate a user materialization for it.
-    bool isMaterializationLive = isLive(opResult);
-    if (!isMaterializationLive)
-      continue;
-    if (!necessaryMaterializations.insert(mat))
-      continue;
-
-    // Reprocess input materializations to see if they have an updated status.
-    for (Value input : inputOperands) {
-      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
-        if (auto *mat = materializationOps.lookup(parentOp))
-          worklist.insert(mat);
-      }
-    }
-  }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
-    UnresolvedMaterializationRewrite &mat,
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  auto findLiveUser = [&](auto &&users) {
-    auto liveUserIt = llvm::find_if_not(
-        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
-    return liveUserIt == users.end() ? nullptr : *liveUserIt;
-  };
-
-  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
-      [&](Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type)
-          return remappedValue;
-        return Value();
-      };
-
-  UnrealizedConversionCastOp op = mat.getOperation();
-  if (!rewriterImpl.ignoredOps.insert(op))
-    return success();
-
-  // We currently only handle target materializations here.
-  OpResult opResult = op->getOpResult(0);
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = opResult.getType();
-
-  // If any input to this materialization is another materialization, resolve
-  // the input first.
-  for (Value value : op->getOperands()) {
-    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
-    if (!valueCast)
-      continue;
-
-    auto matIt = materializationOps.find(valueCast);
-    if (matIt != materializationOps.end())
-      if (failed(legalizeUnresolvedMaterialization(
-              *matIt->second, materializationOps, rewriter, rewriterImpl,
-              inverseMapping)))
-        return failure();
-  }
-
-  // Perform a last ditch attempt to avoid materializing a resolved
-  // materialization if possible.
-  // Handle the case of a 1-1 materialization.
-  if (inputOperands.size() == 1) {
-    // Check to see if the input operation was remapped to a variant of the
-    // output.
-    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
-    if (remappedValue && remappedValue != opResult) {
-      replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                             inverseMapping);
-      return success();
-    }
-  } else {
-    // TODO: Avoid materializing other types of conversions here.
-  }
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = mat.getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (mat.getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      replaceMaterialization(rewriterImpl, opResult, newMaterialization,
-                             inverseMapping);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  if (Operation *liveUser = findLiveUser(op->getUsers())) {
-    diag.attachNote(liveUser->getLoc())
-        << "see existing live user here: " << *liveUser;
-  }
-  return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // As an initial step, compute all of the inserted materializations that we
-  // expect to persist beyond the conversion process.
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
-  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
-  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
-                                   inverseMapping, necessaryMaterializations);
-
-  // Once computed, legalize any necessary materializations.
-  for (auto *mat : necessaryMaterializations) {
-    if (failed(legalizeUnresolvedMaterialization(
-            *mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
-      return failure();
-  }
-  return success();
-}
-
 LogicalResult OperationConverter::legalizeErasedResult(
     Operation *op, OpResult result,
     ConversionPatternRewriterImpl &rewriterImpl) {
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 156a8a468d5b42..75362378daaaaa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1286,7 +1286,6 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 
 // CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
 // CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
 // CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1299,8 +1298,8 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 // CHECK: %[[S136:.+]] = llvm.insertvalue %[[S134]], %[[S135]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S138:.+]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index a192434c5accf8..ab18ce05e355d3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -80,6 +80,7 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
   %0 = bufferization.to_tensor %m : memref<?xf32>
   // expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
   %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  // expected-note @below{{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
 
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index cf2c9f6a8ec441..f130adff42f8cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -4,6 +4,7 @@
 func.func @test_invalid_arg_materialization(
   // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -22,6 +23,7 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
 func.func @test_invalid_result_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -30,6 +32,7 @@ func.func @test_invalid_result_materialization() {
 func.func @test_invalid_result_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -49,6 +52,7 @@ func.func @test_transitive_use_materialization() {
 func.func @test_transitive_use_invalid_materialization() {
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
   %result = "test.another_type_producer"() : () -> f16
+  // expected-note@below{{see existing live user here}}
   "foo.return"(%result) : (f16) -> ()
 }
 
@@ -99,9 +103,9 @@ func.func @test_block_argument_not_converted() {
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
-  // expected-note@below {{see existing live user here}}
   ^bb0(%arg0: f32):
     "test.type_consumer"(%arg0) : (f32) -> ()
+    // expected-note@below{{see existing live user here}}
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return



More information about the Mlir-commits mailing list