[Mlir-commits] [mlir] 3815f47 - [mlir][Transforms] Dialect conversion: Make materializations optional (#107109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 5 10:41:03 PDT 2024
Author: Matthias Springer
Date: 2024-09-05T19:40:58+02:00
New Revision: 3815f478bb4f1c724d36044a4e0bbd3352313322
URL: https://github.com/llvm/llvm-project/commit/3815f478bb4f1c724d36044a4e0bbd3352313322
DIFF: https://github.com/llvm/llvm-project/commit/3815f478bb4f1c724d36044a4e0bbd3352313322.diff
LOG: [mlir][Transforms] Dialect conversion: Make materializations optional (#107109)
This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.
By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.
There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to be
unnecessary) were folded. This analysis was needed because folding
happened at a point in time when some IR changes (e.g., op replacements)
had not materialized yet.
This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).
After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.
`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)
Note: This is a reupload of #104668, but with correct handling of cyclic
unrealized_conversion_casts that may be generated by the dialect
conversion.
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
mlir/test/Transforms/test-legalize-type-conversion.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 60113bdef16a23..5f680e8eca7559 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,6 +1124,17 @@ struct ConversionConfig {
// already been modified) and iterators into past IR state cannot be
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
+
+ /// If set to "true", the dialect conversion attempts to build source/target/
+ /// argument materializations through the type converter API in lieu of
+ /// builtin.unrealized_conversion_cast ops. The conversion process fails if
+ /// at least one materialization could not be built.
+ ///
+ /// If set to "false", the dialect conversion does not build any
+ /// custom materializations and instead inserts
+ /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
+ /// is valid.
+ bool buildMaterializations = true;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b23fb97959ed67..450e66f0db4e74 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,14 +702,12 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
}
+ void rollback() override;
+
UnrealizedConversionCastOp getOperation() const {
return cast<UnrealizedConversionCastOp>(op);
}
- void rollback() override;
-
- void cleanup(RewriterBase &rewriter) override;
-
/// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
return converterAndKind.getPointer();
@@ -766,7 +764,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : context(ctx), eraseRewriter(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -834,6 +832,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
//===--------------------------------------------------------------------===//
// Materializations
//===--------------------------------------------------------------------===//
+
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -882,7 +881,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
- if (erased.contains(op))
+ if (wasErased(op))
return;
op->dropAllUses();
RewriterBase::eraseOp(op);
@@ -890,17 +889,24 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Erase the given block (unless it was already erased).
void eraseBlock(Block *block) override {
- if (erased.contains(block))
+ if (wasErased(block))
return;
assert(block->empty() && "expected empty block");
block->dropAllDefinedValueUses();
RewriterBase::eraseBlock(block);
}
+ bool wasErased(void *ptr) const { return erased.contains(ptr); }
+
+ bool wasErased(OperationRewrite *rewrite) const {
+ return wasErased(rewrite->getOperation());
+ }
+
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
+ private:
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
};
@@ -912,6 +918,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// MLIR context.
MLIRContext *context;
+ /// A rewriter that keeps track of ops/blocks that were already erased and
+ /// skips duplicate op/block erasures. This rewriter is used during the
+ /// "cleanup" phase.
+ SingleEraseRewriter eraseRewriter;
+
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
@@ -1058,10 +1069,6 @@ void UnresolvedMaterializationRewrite::rollback() {
op->erase();
}
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
- rewriter.eraseOp(op);
-}
-
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
@@ -1069,7 +1076,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrite->commit(rewriter);
// Clean up all rewrites.
- SingleEraseRewriter eraseRewriter(context);
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2353,12 +2359,6 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);
- /// Legalize any unresolved type materializations.
- LogicalResult legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
@@ -2405,6 +2405,128 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return success();
}
+/// Try to replace the given unresolved materialization (an
+/// unrealized_conversion_cast op) with a materialization built through the
+/// type converter stored in `rewrite`.
+///
+/// The new materialization is inserted right before the cast op, and the cast
+/// op is then replaced with it through `rewriter`. Only the first result type
+/// of the cast op is used as the materialization's output type. If an
+/// argument materialization produces no value, a target materialization is
+/// attempted instead.
+///
+/// Returns success if a replacement value was built. Otherwise, emits an
+/// error on the cast op and returns failure. This happens when there is no
+/// type converter or the converter produced no value. The error has a note
+/// pointing at the op's first user.
+static LogicalResult
+legalizeUnresolvedMaterialization(RewriterBase &rewriter,
+ UnresolvedMaterializationRewrite *rewrite) {
+ UnrealizedConversionCastOp op = rewrite->getOperation();
+ assert(!op.use_empty() &&
+ "expected that dead materializations have already been DCE'd");
+ Operation::operand_range inputOperands = op.getOperands();
+ Type outputType = op.getResultTypes()[0];
+
+ // Try to materialize the conversion.
+ if (const TypeConverter *converter = rewrite->getConverter()) {
+ rewriter.setInsertionPoint(op);
+ Value newMaterialization;
+ switch (rewrite->getMaterializationKind()) {
+ case MaterializationKind::Argument:
+ // Try to materialize an argument conversion.
+ newMaterialization = converter->materializeArgumentConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ if (newMaterialization)
+ break;
+ // If an argument materialization failed, fallback to trying a target
+ // materialization.
+ [[fallthrough]];
+ case MaterializationKind::Target:
+ newMaterialization = converter->materializeTargetConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ break;
+ case MaterializationKind::Source:
+ newMaterialization = converter->materializeSourceConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ break;
+ }
+ if (newMaterialization) {
+ assert(newMaterialization.getType() == outputType &&
+ "materialization callback produced value of incorrect type");
+ rewriter.replaceOp(op, newMaterialization);
+ return success();
+ }
+ }
+
+ // No converter, or the converter could not build a value: report failure.
+ InFlightDiagnostic diag = op->emitError()
+ << "failed to legalize unresolved materialization "
+ "from ("
+ << inputOperands.getTypes() << ") to " << outputType
+ << " that remained live after conversion";
+ // Point at the first user; `op` is expected to have uses (asserted above).
+ // NOTE(review): with assertions disabled, a use-free op would make this
+ // dereference an end iterator — callers must pass only live casts.
+ diag.attachNote(op->getUsers().begin()->getLoc())
+ << "see existing live user here: " << *op->getUsers().begin();
+ return failure();
+}
+
+/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
+/// are not used (transitively) by any op that is not in the given list of
+/// cast ops.
+///
+/// In particular, this function erases cyclic casts that may be inserted
+/// during the dialect conversion process. E.g.:
+/// %0 = unrealized_conversion_cast(%1)
+/// %1 = unrealized_conversion_cast(%0)
+///
+/// If `remainingCastOps` is non-null, every live cast op is appended to it,
+/// in the order in which it appears in `castOps`.
+///
+/// Note: This step will become unnecessary when
+/// https://github.com/llvm/llvm-project/pull/106760 has been merged.
+static void eraseDeadUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Ops that have already been visited or are currently being visited.
+ DenseSet<Operation *> visited;
+ // Set of all cast ops for faster lookups.
+ DenseSet<Operation *> castOpSet;
+ // Set of all cast ops that have been determined to be alive.
+ DenseSet<Operation *> live;
+
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+
+ // Visit a cast operation. Return "true" if the operation is live.
+ std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
+ // No need to traverse any IR if the op was already marked as live.
+ if (live.contains(op))
+ return true;
+
+ // Do not visit ops multiple times. If we find a cycle, no live user was
+ // found on the current path.
+ if (visited.contains(op))
+ return false;
+ visited.insert(op);
+
+ // Visit all users.
+ for (Operation *user : op->getUsers()) {
+ // If the user is not an unrealized_conversion_cast op, then the given op
+ // is live.
+ if (!castOpSet.contains(user)) {
+ live.insert(op);
+ return true;
+ }
+ // Otherwise, it is live if a live op can be reached from one of its
+ // users (which must all be unrealized_conversion_cast ops).
+ if (visit(user)) {
+ live.insert(op);
+ return true;
+ }
+ }
+
+ return false;
+ };
+
+ // Visit all cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ visit(op);
+ // `visited` only guards a single traversal (e.g. around a cycle). Reset it
+ // so that an op whose exploration was cut short during another root's
+ // traversal is fully explored again; only "live" results are kept.
+ visited.clear();
+ }
+
+ // Erase all cast ops that are dead.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (live.contains(op)) {
+ if (remainingCastOps)
+ remainingCastOps->push_back(op);
+ continue;
+ }
+ // A dead op's users are all casts from `castOps` (any other user would
+ // have made it live), e.g. its partner in a cycle. Drop those uses so the
+ // op can be erased regardless of erasure order.
+ op->dropAllUses();
+ op->erase();
+ }
+}
+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
@@ -2446,6 +2568,38 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
} else {
rewriterImpl.applyRewrites();
}
+
+ // Gather all unresolved materializations.
+ SmallVector<UnrealizedConversionCastOp> allCastOps;
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
+ for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
+ auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+ if (!mat)
+ continue;
+ if (rewriterImpl.eraseRewriter.wasErased(mat))
+ continue;
+ allCastOps.push_back(mat->getOperation());
+ rewriteMap[mat->getOperation()] = mat;
+ }
+
+ // Reconcile all UnrealizedConversionCastOps that were inserted by the
+ // dialect conversion framework. (Not the ones that were inserted by
+ // patterns.)
+ SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
+ eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
+ reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
+
+ // Try to legalize all unresolved materializations.
+ if (config.buildMaterializations) {
+ IRRewriter rewriter(rewriterImpl.context, config.listener);
+ for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
+ auto it = rewriteMap.find(castOp.getOperation());
+ assert(it != rewriteMap.end() && "inconsistent state");
+ if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+ return failure();
+ }
+ }
+
return success();
}
@@ -2459,9 +2613,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
- if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
return success();
}
@@ -2577,279 +2728,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
return success();
}
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
- ResultRange matResults, ValueRange values,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- matResults.replaceAllUsesWith(values);
-
- // For each of the materialization results, update the inverse mappings to
- // point to the replacement values.
- for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
- auto inverseMapIt = inverseMapping.find(matResult);
- if (inverseMapIt == inverseMapping.end())
- continue;
-
- // Update the reverse mapping, or remove the mapping if we couldn't update
- // it. Not being able to update signals that the mapping would have become
- // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
- // propagated through temporary materializations. We simply drop the
- // mapping, and let the post-conversion replacement logic handle updating
- // uses.
- for (Value inverseMapVal : inverseMapIt->second)
- if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
- rewriterImpl.mapping.erase(inverseMapVal);
- }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping,
- SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
- // Helper function to check if the given value or a not yet materialized
- // replacement of the given value is live.
- // Note: `inverseMapping` maps from replaced values to original values.
- auto isLive = [&](Value value) {
- auto findFn = [&](Operation *user) {
- auto matIt = materializationOps.find(user);
- if (matIt != materializationOps.end())
- return !necessaryMaterializations.count(matIt->second);
- return rewriterImpl.isOpIgnored(user);
- };
- // A worklist is needed because a value may have gone through a chain of
- // replacements and each of the replaced values may have live users.
- SmallVector<Value> worklist;
- worklist.push_back(value);
- while (!worklist.empty()) {
- Value next = worklist.pop_back_val();
- if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
- return true;
- // This value may be replacing another value that has a live user.
- llvm::append_range(worklist, inverseMapping.lookup(next));
- }
- return false;
- };
-
- llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
- [&](Value invalidRoot, Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type && remappedValue != invalidRoot)
- return remappedValue;
-
- // Check to see if the input is a materialization operation that
- // provides an inverse conversion. We just check blindly for
- // UnrealizedConversionCastOp here, but it has no effect on correctness.
- auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (inputCastOp && inputCastOp->getNumOperands() == 1)
- return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
- type);
-
- return Value();
- };
-
- SetVector<UnresolvedMaterializationRewrite *> worklist;
- for (auto &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- materializationOps.try_emplace(mat->getOperation(), mat);
- worklist.insert(mat);
- }
- while (!worklist.empty()) {
- UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
- UnrealizedConversionCastOp op = mat->getOperation();
-
- // We currently only handle target materializations here.
- assert(op->getNumResults() == 1 && "unexpected materialization type");
- OpResult opResult = op->getOpResult(0);
- Type outputType = opResult.getType();
- Operation::operand_range inputOperands = op.getOperands();
-
- // Try to forward propagate operands for user conversion casts that result
- // in the input types of the current cast.
- for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
- if (!castOp)
- continue;
- if (castOp->getResultTypes() == inputOperands.getTypes()) {
- replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
- inverseMapping);
- necessaryMaterializations.remove(materializationOps.lookup(user));
- }
- }
-
- // Try to avoid materializing a resolved materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue =
- lookupRemappedValue(opResult, inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- necessaryMaterializations.remove(mat);
- continue;
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // If the materialization does not have any live users, we don't need to
- // generate a user materialization for it.
- bool isMaterializationLive = isLive(opResult);
- if (!isMaterializationLive)
- continue;
- if (!necessaryMaterializations.insert(mat))
- continue;
-
- // Reprocess input materializations to see if they have an updated status.
- for (Value input : inputOperands) {
- if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
- if (auto *mat = materializationOps.lookup(parentOp))
- worklist.insert(mat);
- }
- }
- }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
- UnresolvedMaterializationRewrite &mat,
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- auto findLiveUser = [&](auto &&users) {
- auto liveUserIt = llvm::find_if_not(
- users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
- return liveUserIt == users.end() ? nullptr : *liveUserIt;
- };
-
- llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
- [&](Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type)
- return remappedValue;
- return Value();
- };
-
- UnrealizedConversionCastOp op = mat.getOperation();
- if (!rewriterImpl.ignoredOps.insert(op))
- return success();
-
- // We currently only handle target materializations here.
- OpResult opResult = op->getOpResult(0);
- Operation::operand_range inputOperands = op.getOperands();
- Type outputType = opResult.getType();
-
- // If any input to this materialization is another materialization, resolve
- // the input first.
- for (Value value : op->getOperands()) {
- auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!valueCast)
- continue;
-
- auto matIt = materializationOps.find(valueCast);
- if (matIt != materializationOps.end())
- if (failed(legalizeUnresolvedMaterialization(
- *matIt->second, materializationOps, rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
- }
-
- // Perform a last ditch attempt to avoid materializing a resolved
- // materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- return success();
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // Try to materialize the conversion.
- if (const TypeConverter *converter = mat.getConverter()) {
- rewriter.setInsertionPoint(op);
- Value newMaterialization;
- switch (mat.getMaterializationKind()) {
- case MaterializationKind::Argument:
- // Try to materialize an argument conversion.
- newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- if (newMaterialization)
- break;
- // If an argument materialization failed, fallback to trying a target
- // materialization.
- [[fallthrough]];
- case MaterializationKind::Target:
- newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- }
- if (newMaterialization) {
- assert(newMaterialization.getType() == outputType &&
- "materialization callback produced value of incorrect type");
- replaceMaterialization(rewriterImpl, opResult, newMaterialization,
- inverseMapping);
- return success();
- }
- }
-
- InFlightDiagnostic diag = op->emitError()
- << "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to " << outputType
- << " that remained live after conversion";
- if (Operation *liveUser = findLiveUser(op->getUsers())) {
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- }
- return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- // As an initial step, compute all of the inserted materializations that we
- // expect to persist beyond the conversion process.
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
- SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
- computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
- inverseMapping, necessaryMaterializations);
-
- // Once computed, legalize any necessary materializations.
- for (auto *mat : necessaryMaterializations) {
- if (failed(legalizeUnresolvedMaterialization(
- *mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
- return failure();
- }
- return success();
-}
-
LogicalResult OperationConverter::legalizeErasedResult(
Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl) {
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 156a8a468d5b42..75362378daaaaa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1286,7 +1286,6 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1299,8 +1298,8 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK: %[[S136:.+]] = llvm.insertvalue %[[S134]], %[[S135]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S138:.+]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.mma_async
// CHECK: nvvm.wgmma.mma_async
// CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index a192434c5accf8..ab18ce05e355d3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -80,6 +80,7 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
%0 = bufferization.to_tensor %m : memref<?xf32>
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ // expected-note @below{{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index cf2c9f6a8ec441..f130adff42f8cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -4,6 +4,7 @@
func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
+ // expected-note@below{{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@@ -22,6 +23,7 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -30,6 +32,7 @@ func.func @test_invalid_result_materialization() {
func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -49,6 +52,7 @@ func.func @test_transitive_use_materialization() {
func.func @test_transitive_use_invalid_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.another_type_producer"() : () -> f16
+ // expected-note@below{{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -99,9 +103,9 @@ func.func @test_block_argument_not_converted() {
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
- // expected-note@below {{see existing live user here}}
^bb0(%arg0: f32):
"test.type_consumer"(%arg0) : (f32) -> ()
+ // expected-note@below{{see existing live user here}}
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
return
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index a789ab9a82e192..e5503ee8920424 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -452,3 +452,14 @@ func.func @convert_detached_signature() {
}) : () -> ()
"test.return"() : () -> ()
}
+
+// -----
+
+// CHECK-LABEL: func @circular_mapping()
+// CHECK-NEXT: "test.valid"() : () -> ()
+func.func @circular_mapping() {
+ // Regression test that used to crash due to circular
+ // unrealized_conversion_cast ops.
+ %0 = "test.erase_op"() : () -> (i64)
+ "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 91dfb2faa80a17..3cbc307835afd7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -907,6 +907,22 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
return success();
}
};
+/// Replace the op with a result-less, operand-less TestValidOp, i.e. simply
+/// drop the operands. This is used in a regression test for a bug where we
+/// used to generate circular unrealized_conversion_cast ops.
+struct TestDropAndReplaceInvalidOp : public ConversionPattern {
+ TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter,
+ "test.drop_operands_and_replace_with_valid", 1, ctx) {
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
+ std::nullopt);
+ return success();
+ }
+};
/// This pattern handles the case of a split return value.
struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
@@ -1070,6 +1086,19 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
return success();
};
};
+
+class TestEraseOp : public ConversionPattern {
+public:
+ TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Erase the op without providing replacement values for its results.
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1148,8 +1177,9 @@ struct TestLegalizePatternDriver
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification>(&getContext());
- patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
+ TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
+ &getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
More information about the Mlir-commits
mailing list