[Mlir-commits] [mlir] [mlir][Transforms][WIP] Dialect conversion: Make materializations optional (PR #104668)
Matthias Springer
llvmlistbot at llvm.org
Sat Aug 17 02:40:50 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/104668
Build all source/target/argument materializations after the conversion has succeeded. Provide a new configuration option for users to opt out of all automatic materializations. In that case, the resulting IR will have `builtin.unrealized_conversion_cast` ops.
Draft only, do not merge.
>From 23c0a47fd8e98689d66776c7b102ff0d2ccd91c5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 Aug 2024 11:38:40 +0200
Subject: [PATCH] [mlir][Transforms][WIP] Dialect conversion: Make
materializations optional
Build all source/target/argument materializations after the conversion has succeeded. Provide a new configuration option for users to opt out of all automatic materializations. In that case, the resulting IR will have `builtin.unrealized_conversion_cast` ops.
---
.../Transforms/Utils/DialectConversion.cpp | 474 +++++++-----------
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 5 +-
.../Transforms/finalizing-bufferize.mlir | 1 +
.../test-legalize-type-conversion.mlir | 6 +-
4 files changed, 179 insertions(+), 307 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..56eec3e8b00c16 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -686,6 +686,16 @@ enum MaterializationKind {
Source
};
+/// Bookkeeping for one unresolved materialization: the type converter whose
+/// materialization callbacks should be tried, and which kind of
+/// materialization (argument/target/source) the op stands for.
+struct UnresolvedMaterializationMetadata {
+  UnresolvedMaterializationMetadata() = default;
+  UnresolvedMaterializationMetadata(const TypeConverter *converter,
+                                    MaterializationKind kind)
+      : converter(converter), kind(kind) {}
+
+  /// Converter to use when resolving the materialization. May be null, in
+  /// which case no materialization callback can be invoked.
+  const TypeConverter *converter = nullptr;
+  /// Kind of the materialization. Defaults to a target materialization so
+  /// that a default-constructed object never carries an indeterminate value.
+  MaterializationKind kind = MaterializationKind::Target;
+};
+
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
/// op. Unresolved materializations are erased at the end of the dialect
/// conversion.
@@ -696,7 +706,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ metadata(converter, kind) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,23 +718,12 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
void rollback() override;
- void cleanup(RewriterBase &rewriter) override;
-
- /// Return the type converter of this materialization (which may be null).
- const TypeConverter *getConverter() const {
- return converterAndKind.getPointer();
- }
-
- /// Return the kind of this materialization.
- MaterializationKind getMaterializationKind() const {
- return converterAndKind.getInt();
- }
+ UnresolvedMaterializationMetadata getMetadata() const { return metadata; }
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
- converterAndKind;
+ UnresolvedMaterializationMetadata metadata;
};
} // namespace
@@ -834,6 +833,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
//===--------------------------------------------------------------------===//
// Materializations
//===--------------------------------------------------------------------===//
+
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -841,6 +841,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
const TypeConverter *converter);
+ SmallVector<UnrealizedConversionCastOp> getUnresolvedMaterializations();
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -877,8 +879,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
- SingleEraseRewriter(MLIRContext *context)
- : RewriterBase(context, /*listener=*/this) {}
+ SingleEraseRewriter(MLIRContext *ctx)
+ : RewriterBase(ctx, /*listener=*/this) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -930,6 +932,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
+ SmallVector<Operation *> unresolvedMaterializations;
+ DenseSet<Operation *> erasedUnresolvedMaterializations;
+ DenseMap<Operation *, UnresolvedMaterializationMetadata>
+ unresolvedMaterializationMetadata;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1051,17 +1058,17 @@ void CreateOperationRewrite::rollback() {
}
void UnresolvedMaterializationRewrite::rollback() {
- if (getMaterializationKind() == MaterializationKind::Target) {
+ // Undo the bookkeeping done when the materialization was built: the op is
+ // about to be erased, so neither the list nor the metadata map may keep a
+ // pointer to it.
+ auto it = llvm::find(rewriterImpl.unresolvedMaterializations, op);
+ assert(it != rewriterImpl.unresolvedMaterializations.end() &&
+ "inconsistent state");
+ rewriterImpl.unresolvedMaterializations.erase(it);
+ rewriterImpl.unresolvedMaterializationMetadata.erase(op);
+ if (metadata.kind == MaterializationKind::Target) {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
op->erase();
}
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
- rewriter.eraseOp(op);
-}
-
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
@@ -1072,6 +1079,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
SingleEraseRewriter eraseRewriter(context);
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
+
+ for (void *op : eraseRewriter.erased)
+ erasedUnresolvedMaterializations.insert(reinterpret_cast<Operation *>(op));
}
//===----------------------------------------------------------------------===//
@@ -1342,10 +1352,22 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ unresolvedMaterializations.push_back(convertOp);
+ unresolvedMaterializationMetadata[convertOp] =
+ UnresolvedMaterializationMetadata(converter, kind);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
+/// Return the recorded unresolved materializations that are not in the set of
+/// erased materializations, in the order in which they were recorded.
+SmallVector<UnrealizedConversionCastOp>
+ConversionPatternRewriterImpl::getUnresolvedMaterializations() {
+  SmallVector<UnrealizedConversionCastOp> liveCasts;
+  for (Operation *candidate : unresolvedMaterializations) {
+    // Entries in the erased set were collected from the erase rewriter in
+    // `applyRewrites`; skip them.
+    if (erasedUnresolvedMaterializations.contains(candidate))
+      continue;
+    liveCasts.push_back(cast<UnrealizedConversionCastOp>(candidate));
+  }
+  return liveCasts;
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2354,12 +2376,6 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);
- /// Legalize any unresolved type materializations.
- LogicalResult legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
@@ -2406,6 +2422,124 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return success();
}
+/// Simplify the given unrealized_conversion_cast ops in place:
+/// * casts without users are erased, and
+/// * a cast whose result types equal the input types of a cast further up
+/// its chain of producing casts is replaced with those inputs and erased.
+///
+/// Casts that feed a processed cast are added to the worklist as well, so ops
+/// that are not part of `castOps` may also be erased. Returns the ops from
+/// `castOps` that were not erased, in their original order.
+static SmallVector<UnrealizedConversionCastOp>
+reconcileUnrealizedCasts(ArrayRef<UnrealizedConversionCastOp> castOps) {
+ SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+ castOps.end());
+ // Ops erased by this function; used at the end to filter `castOps`.
+ DenseSet<Operation *> erasedOps;
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp =
+ castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ // The producing cast must define exactly the inputs of `castOp`.
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ // Note: The operands are enqueued before the op is erased because they
+ // are read from the op itself.
+ enqueueOperands(castOp);
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ // The traversal starts at `castOp` itself, so a cast whose input types
+ // already equal its result types is folded away as well.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
+
+ // Report only the surviving ops of the original input list.
+ SmallVector<UnrealizedConversionCastOp> result;
+ for (UnrealizedConversionCastOp op : castOps)
+ if (!erasedOps.contains(op.getOperation()))
+ result.push_back(op);
+
+ return result;
+}
+
+/// Replace the given unresolved materialization (an
+/// unrealized_conversion_cast op) with a materialization built by the type
+/// converter stored in `metadata`.
+///
+/// If `metadata.converter` is non-null, the callback that corresponds to
+/// `metadata.kind` is invoked at the position of `op`; a failed argument
+/// materialization falls back to a target materialization. On success, `op` is
+/// replaced with the produced value. If there is no converter or the callback
+/// produced no value, an error is emitted on `op` and failure is returned.
+static LogicalResult
+legalizeUnresolveMaterialization(RewriterBase &rewriter,
+                                 UnrealizedConversionCastOp op,
+                                 UnresolvedMaterializationMetadata metadata) {
+  assert(!op.use_empty() &&
+         "expected that dead materializations have already been DCE'd");
+  Operation::operand_range inputOperands = op.getOperands();
+  // Only the first result type is considered here.
+  Type outputType = op.getResultTypes()[0];
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = metadata.converter) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (metadata.kind) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      rewriter.replaceOp(op, newMaterialization);
+      return success();
+    }
+  }
+
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
+                            << " that remained live after conversion";
+  // The assert above is compiled out in release builds, so do not dereference
+  // the user iterator unless there really is a user.
+  if (!op->use_empty()) {
+    Operation *liveUser = *op->getUsers().begin();
+    diag.attachNote(liveUser->getLoc())
+        << "see existing live user here: " << *liveUser;
+  }
+  return failure();
+}
+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
@@ -2447,6 +2581,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
} else {
rewriterImpl.applyRewrites();
}
+
+ IRRewriter re2(rewriterImpl.context);
+ SmallVector<UnrealizedConversionCastOp> castOps =
+ reconcileUnrealizedCasts(rewriterImpl.getUnresolvedMaterializations());
+ for (UnrealizedConversionCastOp op : castOps)
+ if (failed(legalizeUnresolveMaterialization(
+ re2, op,
+ rewriterImpl.unresolvedMaterializationMetadata[op.getOperation()])))
+ return failure();
+
return success();
}
@@ -2460,9 +2604,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
- if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
return success();
}
@@ -2577,279 +2718,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
return success();
}
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
- ResultRange matResults, ValueRange values,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- matResults.replaceAllUsesWith(values);
-
- // For each of the materialization results, update the inverse mappings to
- // point to the replacement values.
- for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
- auto inverseMapIt = inverseMapping.find(matResult);
- if (inverseMapIt == inverseMapping.end())
- continue;
-
- // Update the reverse mapping, or remove the mapping if we couldn't update
- // it. Not being able to update signals that the mapping would have become
- // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
- // propagated through temporary materializations. We simply drop the
- // mapping, and let the post-conversion replacement logic handle updating
- // uses.
- for (Value inverseMapVal : inverseMapIt->second)
- if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
- rewriterImpl.mapping.erase(inverseMapVal);
- }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping,
- SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
- // Helper function to check if the given value or a not yet materialized
- // replacement of the given value is live.
- // Note: `inverseMapping` maps from replaced values to original values.
- auto isLive = [&](Value value) {
- auto findFn = [&](Operation *user) {
- auto matIt = materializationOps.find(user);
- if (matIt != materializationOps.end())
- return !necessaryMaterializations.count(matIt->second);
- return rewriterImpl.isOpIgnored(user);
- };
- // A worklist is needed because a value may have gone through a chain of
- // replacements and each of the replaced values may have live users.
- SmallVector<Value> worklist;
- worklist.push_back(value);
- while (!worklist.empty()) {
- Value next = worklist.pop_back_val();
- if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
- return true;
- // This value may be replacing another value that has a live user.
- llvm::append_range(worklist, inverseMapping.lookup(next));
- }
- return false;
- };
-
- llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
- [&](Value invalidRoot, Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type && remappedValue != invalidRoot)
- return remappedValue;
-
- // Check to see if the input is a materialization operation that
- // provides an inverse conversion. We just check blindly for
- // UnrealizedConversionCastOp here, but it has no effect on correctness.
- auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (inputCastOp && inputCastOp->getNumOperands() == 1)
- return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
- type);
-
- return Value();
- };
-
- SetVector<UnresolvedMaterializationRewrite *> worklist;
- for (auto &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- materializationOps.try_emplace(mat->getOperation(), mat);
- worklist.insert(mat);
- }
- while (!worklist.empty()) {
- UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
- UnrealizedConversionCastOp op = mat->getOperation();
-
- // We currently only handle target materializations here.
- assert(op->getNumResults() == 1 && "unexpected materialization type");
- OpResult opResult = op->getOpResult(0);
- Type outputType = opResult.getType();
- Operation::operand_range inputOperands = op.getOperands();
-
- // Try to forward propagate operands for user conversion casts that result
- // in the input types of the current cast.
- for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
- if (!castOp)
- continue;
- if (castOp->getResultTypes() == inputOperands.getTypes()) {
- replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
- inverseMapping);
- necessaryMaterializations.remove(materializationOps.lookup(user));
- }
- }
-
- // Try to avoid materializing a resolved materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue =
- lookupRemappedValue(opResult, inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- necessaryMaterializations.remove(mat);
- continue;
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // If the materialization does not have any live users, we don't need to
- // generate a user materialization for it.
- bool isMaterializationLive = isLive(opResult);
- if (!isMaterializationLive)
- continue;
- if (!necessaryMaterializations.insert(mat))
- continue;
-
- // Reprocess input materializations to see if they have an updated status.
- for (Value input : inputOperands) {
- if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
- if (auto *mat = materializationOps.lookup(parentOp))
- worklist.insert(mat);
- }
- }
- }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
- UnresolvedMaterializationRewrite &mat,
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- auto findLiveUser = [&](auto &&users) {
- auto liveUserIt = llvm::find_if_not(
- users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
- return liveUserIt == users.end() ? nullptr : *liveUserIt;
- };
-
- llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
- [&](Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type)
- return remappedValue;
- return Value();
- };
-
- UnrealizedConversionCastOp op = mat.getOperation();
- if (!rewriterImpl.ignoredOps.insert(op))
- return success();
-
- // We currently only handle target materializations here.
- OpResult opResult = op->getOpResult(0);
- Operation::operand_range inputOperands = op.getOperands();
- Type outputType = opResult.getType();
-
- // If any input to this materialization is another materialization, resolve
- // the input first.
- for (Value value : op->getOperands()) {
- auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!valueCast)
- continue;
-
- auto matIt = materializationOps.find(valueCast);
- if (matIt != materializationOps.end())
- if (failed(legalizeUnresolvedMaterialization(
- *matIt->second, materializationOps, rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
- }
-
- // Perform a last ditch attempt to avoid materializing a resolved
- // materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- return success();
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // Try to materialize the conversion.
- if (const TypeConverter *converter = mat.getConverter()) {
- rewriter.setInsertionPoint(op);
- Value newMaterialization;
- switch (mat.getMaterializationKind()) {
- case MaterializationKind::Argument:
- // Try to materialize an argument conversion.
- newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- if (newMaterialization)
- break;
- // If an argument materialization failed, fallback to trying a target
- // materialization.
- [[fallthrough]];
- case MaterializationKind::Target:
- newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- }
- if (newMaterialization) {
- assert(newMaterialization.getType() == outputType &&
- "materialization callback produced value of incorrect type");
- replaceMaterialization(rewriterImpl, opResult, newMaterialization,
- inverseMapping);
- return success();
- }
- }
-
- InFlightDiagnostic diag = op->emitError()
- << "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to " << outputType
- << " that remained live after conversion";
- if (Operation *liveUser = findLiveUser(op->getUsers())) {
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- }
- return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- // As an initial step, compute all of the inserted materializations that we
- // expect to persist beyond the conversion process.
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
- SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
- computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
- inverseMapping, necessaryMaterializations);
-
- // Once computed, legalize any necessary materializations.
- for (auto *mat : necessaryMaterializations) {
- if (failed(legalizeUnresolvedMaterialization(
- *mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
- return failure();
- }
- return success();
-}
-
LogicalResult OperationConverter::legalizeErasedResult(
Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl) {
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 156a8a468d5b42..75362378daaaaa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1286,7 +1286,6 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1299,8 +1298,8 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK: %[[S136:.+]] = llvm.insertvalue %[[S134]], %[[S135]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S138:.+]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.mma_async
// CHECK: nvvm.wgmma.mma_async
// CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index a192434c5accf8..ab18ce05e355d3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -80,6 +80,7 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
%0 = bufferization.to_tensor %m : memref<?xf32>
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ // expected-note @below{{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index cf2c9f6a8ec441..f130adff42f8cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -4,6 +4,7 @@
func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
+ // expected-note@below{{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@@ -22,6 +23,7 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -30,6 +32,7 @@ func.func @test_invalid_result_materialization() {
func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -49,6 +52,7 @@ func.func @test_transitive_use_materialization() {
func.func @test_transitive_use_invalid_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.another_type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -99,9 +103,9 @@ func.func @test_block_argument_not_converted() {
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
- // expected-note@below {{see existing live user here}}
^bb0(%arg0: f32):
"test.type_consumer"(%arg0) : (f32) -> ()
+ // expected-note@below{{see existing live user here}}
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
return
More information about the Mlir-commits
mailing list