[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (PR #108359)
Matthias Springer
llvmlistbot at llvm.org
Fri Sep 13 10:55:26 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/108359
From e724e4491effdfad44d0c1331acd83b8aa75a33b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 12 Sep 2024 12:45:44 +0200
Subject: [PATCH] [mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite`
The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided.
Also delete some dead code.
---
.../Transforms/Utils/DialectConversion.cpp | 72 +++++++------------
1 file changed, 27 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..caea9e111afeda 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ MaterializationKind kind = MaterializationKind::Target);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
- RewriteTy *result = nullptr;
- for (auto &rewrite : rewrites) {
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
- if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
- assert(!result && "expected single matching rewrite");
- result = rewriteTy;
-#else
- return rewriteTy;
-#endif // NDEBUG
- }
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- bool wasErased(OperationRewrite *rewrite) const {
- return wasErased(rewrite->getOperation());
- }
-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
- /// A set of all unresolved materializations.
- DenseSet<Operation *> unresolvedMaterializations;
+ /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+ /// to the corresponding rewrite objects.
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+ const TypeConverter *converter, MaterializationKind kind)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind) {
+ rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
- rewriterImpl.unresolvedMaterializations.erase(op);
+ rewriterImpl.unresolvedMaterializations.erase(getOperation());
op->erase();
}
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
// This result was dropped and no replacement value was provided.
- if (unresolvedMaterializations.contains(op)) {
- // Do not create another materializations if we are erasing a
- // materialization.
- continue;
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ if (unresolvedMaterializations.contains(castOp)) {
+ // Do not create another materializations if we are erasing a
+ // materialization.
+ continue;
+ }
}
// Materialize a replacement value "out of thin air".
@@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- if (rewriterImpl.eraseRewriter.wasErased(mat))
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ &materializations = rewriterImpl.unresolvedMaterializations;
+ for (auto it : materializations) {
+ if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
- allCastOps.push_back(mat->getOperation());
- rewriteMap[mat->getOperation()] = mat;
+ allCastOps.push_back(it.first);
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
- auto it = rewriteMap.find(castOp.getOperation());
- assert(it != rewriteMap.end() && "inconsistent state");
+ auto it = materializations.find(castOp);
+ assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
return failure();
}
More information about the Mlir-commits mailing list